212 lines
5.1 KiB
Go
212 lines
5.1 KiB
Go
package internal
|
|
|
|
import (
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"time"
)
|
|
|
|
// SVGData pairs an SVG file's name with its textual contents, as produced by
// GetSvgsFromPath.
type SVGData struct {
	// Name is the directory-entry name (including extension) the SVG came from.
	Name string
	// SVG is the raw file contents, unmodified.
	SVG string
}
|
|
|
|
func GetSvgsFromPath(path string) ([]SVGData, error) {
|
|
files, err := os.ReadDir(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
svgs := make([]SVGData, 0)
|
|
for _, file := range files {
|
|
if file.IsDir() || !strings.HasSuffix(file.Name(), ".svg") {
|
|
continue
|
|
}
|
|
svgFile, err := os.ReadFile(filepath.Join(path, file.Name()))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
svg := SVGData{
|
|
Name: file.Name(),
|
|
SVG: string(svgFile),
|
|
}
|
|
svgs = append(svgs, svg)
|
|
}
|
|
return svgs, nil
|
|
}
|
|
|
|
// Response is the part of the Recraft image-generation reply that
// GenerateIcon decodes: a list of generated images, each given by a URL that
// is later downloaded.
type Response struct {
	// Data holds one entry per generated image.
	Data []struct {
		// URL is where the generated image can be fetched from.
		URL string `json:"url"`
	} `json:"data"`
}
|
|
|
|
func GenerateIcon(prompt string) (string, error) {
|
|
client := &http.Client{}
|
|
recraftBody := fmt.Sprintf(`{
|
|
"model": "recraftv2",
|
|
"prompt": "%s",
|
|
"style": "icon"
|
|
}`, prompt)
|
|
req, err := http.NewRequest("POST", "https://external.api.recraft.ai/v1/images/generations", strings.NewReader(recraftBody))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Add("Content-Type", "application/json")
|
|
req.Header.Add("Authorization", "Bearer i153LYqEDGOipN2LCCvRBCV2uwJWMH7nZzxX8RKiZwfVTGgAgXBMDVNgvEYk0YtJ")
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("bad status: %s", resp.Status)
|
|
}
|
|
// Step 2: Read the response body into memory
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
var data Response
|
|
if err := json.Unmarshal(body, &data); err != nil {
|
|
return "", fmt.Errorf("unable to unmarshal json: %v", err)
|
|
}
|
|
if len(data.Data) == 0 {
|
|
return "", fmt.Errorf("no data found in response")
|
|
}
|
|
return data.Data[0].URL, nil
|
|
}
|
|
|
|
func DownloadFile(filepath string, url string) error {
|
|
// out, err := os.Create(filepath)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
// defer out.Close()
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("bad status: %s", resp.Status)
|
|
}
|
|
svgData, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cleanSvg := CleanHeader(string(svgData))
|
|
if err = os.WriteFile(filepath, []byte(cleanSvg), 0700); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Character classes that make up the prompt allow-list; see Allowed.
const (
	lowercase = "abcdefghijklmnopqrstuvwxyz"
	uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	digits    = "0123456789"

	// allChars is every allowed character: ASCII letters, ASCII digits and '-'.
	allChars = lowercase + uppercase + digits + "-"
)
|
|
|
|
var Allowed map[rune]struct{} = func() map[rune]struct{} {
|
|
allowed := make(map[rune]struct{})
|
|
for _, ch := range allChars {
|
|
allowed[ch] = struct{}{}
|
|
}
|
|
return allowed
|
|
}()
|
|
|
|
func SanatizePrompt(prompt string) string {
|
|
var sb strings.Builder
|
|
// Replace disallowed characters with '-'
|
|
for _, ch := range prompt {
|
|
if _, ok := Allowed[ch]; ok {
|
|
sb.WriteRune(ch)
|
|
} else {
|
|
sb.WriteRune('-')
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
func FixIconHeader() {
|
|
cliCmd := flag.NewFlagSet("fix", flag.ExitOnError)
|
|
svgPath := cliCmd.String("svg-path", "", "Path to folder with svgs")
|
|
outputPath := cliCmd.String("output-path", "", "Path to put updated svgs")
|
|
if err := cliCmd.Parse(os.Args[2:]); err != nil {
|
|
log.Fatalf("Failed to parse flags: %v", err)
|
|
}
|
|
svgs, err := GetSvgsFromPath(*svgPath)
|
|
if err != nil {
|
|
log.Fatalf("fatal error reading svg dir %v", err)
|
|
}
|
|
if err = os.MkdirAll(*outputPath, 0700); err != nil {
|
|
log.Fatalf("fatal error making output dir %v", err)
|
|
}
|
|
for _, svg := range svgs {
|
|
cleanSvg := CleanHeader(svg.SVG)
|
|
if err = os.WriteFile(filepath.Join(*outputPath, svg.Name), []byte(cleanSvg), 0700); err != nil {
|
|
log.Printf("error writing %s; %v ", svg.Name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// FileExists reports whether os.Stat succeeds for path. Any Stat error,
// including permission errors, counts as "does not exist"; directories count
// as existing.
func FileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
|
|
|
|
func DownloadIconCli() {
|
|
cliCmd := flag.NewFlagSet("download", flag.ExitOnError)
|
|
outputPath := cliCmd.String("output-path", "", "path to output folder")
|
|
prompt := cliCmd.String("prompt", "", "recraft icon prompt")
|
|
if err := cliCmd.Parse(os.Args[2:]); err != nil {
|
|
log.Fatalf("Failed to parse flags: %v", err)
|
|
}
|
|
if _, err := DownloadIcon(*prompt, *outputPath); err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
}
|
|
|
|
func DownloadIcon(prompt string, outputPath string) (string, error) {
|
|
sanatizedPrompt := SanatizePrompt(prompt)
|
|
svgFilePath := FindNoneCollisionFileName(sanatizedPrompt, "svg", outputPath)
|
|
url, err := GenerateIcon(prompt)
|
|
if err != nil {
|
|
log.Printf("error generating icon: %v\n", err)
|
|
return "", err
|
|
}
|
|
if err := DownloadFile(svgFilePath, url); err != nil {
|
|
log.Printf("error downloading %s : %v\n", url, err)
|
|
return "", err
|
|
}
|
|
log.Println("success")
|
|
return svgFilePath, nil
|
|
}
|
|
|
|
func FindNoneCollisionFileName(fileName, fileExt, path string) string {
|
|
|
|
filePath := filepath.Join(path, fileName+"."+fileExt)
|
|
for fileCount := 1; FileExists(filePath); fileCount += 1 {
|
|
log.Println("file exists: ", filePath)
|
|
filePath = filepath.Join(
|
|
path,
|
|
fmt.Sprintf("%s (%d).%s", fileName, fileCount, fileExt),
|
|
)
|
|
}
|
|
return filePath
|
|
}
|
|
|
|
// cleanHeaderRootTagRe matches the first <svg ...> start tag. Quoted attribute
// values are consumed whole, so a '>' inside a value does not end the match.
var cleanHeaderRootTagRe = regexp.MustCompile(`<svg\b(?:[^>"']|"[^"]*"|'[^']*')*>`)

// cleanHeaderAttrRe matches a style, width or height attribute (double- or
// single-quoted, optional spaces around '='). Requiring leading whitespace
// keeps names such as stroke-width or line-height intact. Previously
// "stroke-width" was cut down to the malformed "stroke-".
var cleanHeaderAttrRe = regexp.MustCompile(`\s+(?:style|width|height)\s*=\s*(?:"[^"]*"|'[^']*')`)

// CleanHeader removes the style, width and height attributes from the root
// <svg> start tag of input, presumably so the icon sizes to its container. A
// viewBox, if present, is kept. Child elements are left untouched, because
// stripping width/height from elements such as <rect> or <image> changes the
// drawing itself. Input without an <svg> start tag is returned unchanged.
//
// The regular expressions are compiled once at package scope instead of on
// every call.
func CleanHeader(input string) string {
	loc := cleanHeaderRootTagRe.FindStringIndex(input)
	if loc == nil {
		return input
	}
	start, end := loc[0], loc[1]
	return input[:start] + cleanHeaderAttrRe.ReplaceAllString(input[start:end], "") + input[end:]
}
|