|
package openai |
|
|
|
import ( |
|
"bufio" |
|
"encoding/base64" |
|
"encoding/json" |
|
"fmt" |
|
"io" |
|
"net/http" |
|
"os" |
|
"path/filepath" |
|
"strconv" |
|
"strings" |
|
"time" |
|
|
|
"github.com/google/uuid" |
|
"github.com/mudler/LocalAI/core/config" |
|
"github.com/mudler/LocalAI/core/schema" |
|
|
|
"github.com/mudler/LocalAI/core/backend" |
|
|
|
"github.com/gofiber/fiber/v2" |
|
model "github.com/mudler/LocalAI/pkg/model" |
|
"github.com/rs/zerolog/log" |
|
) |
|
|
|
// downloadFile fetches url with an HTTP GET and stores the response body in a
// newly created temporary file (prefix "image", default temp directory).
//
// On success it returns the path of that file; the caller owns the file and
// must remove it. On any failure it returns an empty path and a non-nil
// error, and no temporary file is left behind.
//
// A response whose status is outside the 2xx range is treated as an error so
// that an HTML error page is never handed on as image data.
//
// NOTE(review): url is passed to the HTTP client as-is. If it originates from
// an untrusted request, this allows server-side request forgery; consider
// validating the destination at the call site.
func downloadFile(url string) (string, error) {
	// Bound the whole exchange (connect, headers and body) so that a stalled
	// remote server cannot block the caller forever.
	const timeout = 2 * time.Minute
	client := &http.Client{Timeout: timeout}

	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("downloading %q: unexpected status %s", url, resp.Status)
	}

	out, err := os.CreateTemp("", "image")
	if err != nil {
		return "", err
	}

	if _, err := io.Copy(out, resp.Body); err != nil {
		// Do not leave a truncated file behind.
		out.Close()
		os.Remove(out.Name())
		return "", err
	}

	// Close explicitly: a failing Close can mean the data never reached disk.
	if err := out.Close(); err != nil {
		os.Remove(out.Name())
		return "", err
	}

	return out.Name(), nil
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { |
|
return func(c *fiber.Ctx) error { |
|
m, input, err := readRequest(c, cl, ml, appConfig, false) |
|
if err != nil { |
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
} |
|
|
|
if m == "" { |
|
m = model.StableDiffusionBackend |
|
} |
|
log.Debug().Msgf("Loading model: %+v", m) |
|
|
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false) |
|
if err != nil { |
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
} |
|
|
|
src := "" |
|
if input.File != "" { |
|
|
|
fileData := []byte{} |
|
|
|
|
|
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { |
|
out, err := downloadFile(input.File) |
|
if err != nil { |
|
return fmt.Errorf("failed downloading file:%w", err) |
|
} |
|
defer os.RemoveAll(out) |
|
|
|
fileData, err = os.ReadFile(out) |
|
if err != nil { |
|
return fmt.Errorf("failed reading file:%w", err) |
|
} |
|
|
|
} else { |
|
|
|
|
|
fileData, err = base64.StdEncoding.DecodeString(input.File) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
|
|
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64") |
|
if err != nil { |
|
return err |
|
} |
|
|
|
writer := bufio.NewWriter(outputFile) |
|
_, err = writer.Write(fileData) |
|
if err != nil { |
|
outputFile.Close() |
|
return err |
|
} |
|
outputFile.Close() |
|
src = outputFile.Name() |
|
defer os.RemoveAll(src) |
|
} |
|
|
|
log.Debug().Msgf("Parameter Config: %+v", config) |
|
|
|
switch config.Backend { |
|
case "stablediffusion": |
|
config.Backend = model.StableDiffusionBackend |
|
case "tinydream": |
|
config.Backend = model.TinyDreamBackend |
|
case "": |
|
config.Backend = model.StableDiffusionBackend |
|
} |
|
|
|
sizeParts := strings.Split(input.Size, "x") |
|
if len(sizeParts) != 2 { |
|
return fmt.Errorf("invalid value for 'size'") |
|
} |
|
width, err := strconv.Atoi(sizeParts[0]) |
|
if err != nil { |
|
return fmt.Errorf("invalid value for 'size'") |
|
} |
|
height, err := strconv.Atoi(sizeParts[1]) |
|
if err != nil { |
|
return fmt.Errorf("invalid value for 'size'") |
|
} |
|
|
|
b64JSON := config.ResponseFormat == "b64_json" |
|
|
|
|
|
var result []schema.Item |
|
for _, i := range config.PromptStrings { |
|
n := input.N |
|
if input.N == 0 { |
|
n = 1 |
|
} |
|
for j := 0; j < n; j++ { |
|
prompts := strings.Split(i, "|") |
|
positive_prompt := prompts[0] |
|
negative_prompt := "" |
|
if len(prompts) > 1 { |
|
negative_prompt = prompts[1] |
|
} |
|
|
|
mode := 0 |
|
step := config.Step |
|
if step == 0 { |
|
step = 15 |
|
} |
|
|
|
if input.Mode != 0 { |
|
mode = input.Mode |
|
} |
|
|
|
if input.Step != 0 { |
|
step = input.Step |
|
} |
|
|
|
tempDir := "" |
|
if !b64JSON { |
|
tempDir = appConfig.ImageDir |
|
} |
|
|
|
outputFile, err := os.CreateTemp(tempDir, "b64") |
|
if err != nil { |
|
return err |
|
} |
|
outputFile.Close() |
|
output := outputFile.Name() + ".png" |
|
|
|
err = os.Rename(outputFile.Name(), output) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
baseURL := c.BaseURL() |
|
|
|
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) |
|
if err != nil { |
|
return err |
|
} |
|
if err := fn(); err != nil { |
|
return err |
|
} |
|
|
|
item := &schema.Item{} |
|
|
|
if b64JSON { |
|
defer os.RemoveAll(output) |
|
data, err := os.ReadFile(output) |
|
if err != nil { |
|
return err |
|
} |
|
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
|
} else { |
|
base := filepath.Base(output) |
|
item.URL = baseURL + "/generated-images/" + base |
|
} |
|
|
|
result = append(result, *item) |
|
} |
|
} |
|
|
|
id := uuid.New().String() |
|
created := int(time.Now().Unix()) |
|
resp := &schema.OpenAIResponse{ |
|
ID: id, |
|
Created: created, |
|
Data: result, |
|
} |
|
|
|
jsonResult, _ := json.Marshal(resp) |
|
log.Debug().Msgf("Response: %s", jsonResult) |
|
|
|
|
|
return c.JSON(resp) |
|
} |
|
} |
|
|