mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
f895d06605
* fix(defaults): set better defaults for inferencing. This changeset aims to provide better defaults and to properly detect when no inference settings are provided with the model. If none are specified, we default to mirostat sampling and offload all layers to the GPU (if a GPU is detected). Related to https://github.com/mudler/LocalAI/issues/1373 and https://github.com/mudler/LocalAI/issues/1723 * Adapt tests * Also pre-initialize default seed
240 lines
5.2 KiB
Go
240 lines
5.2 KiB
Go
package openai
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-skynet/LocalAI/core/config"
|
|
"github.com/go-skynet/LocalAI/core/schema"
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/go-skynet/LocalAI/core/backend"
|
|
|
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// downloadFile fetches url with an HTTP GET and saves the response body to a
// new temporary file (default temp directory, "image" name prefix).
//
// On success it returns the path of that file; the caller owns it and must
// remove it. A response whose status is not 2xx is reported as an error rather
// than saved. On any error the empty string is returned and no temporary file
// is left behind.
func downloadFile(url string) (string, error) {
	// NOTE(review): http.Get uses http.DefaultClient, which has no timeout; a
	// slow or unresponsive server can block the caller indefinitely.
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Without this check an error page (e.g. a 404 body) would be written out
	// and later treated as image data.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("unexpected status %q fetching %q", resp.Status, url)
	}

	out, err := os.CreateTemp("", "image")
	if err != nil {
		return "", err
	}

	_, err = io.Copy(out, resp.Body)
	// Close can surface a write error that io.Copy did not, so its result matters.
	if closeErr := out.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// Best-effort cleanup of the partial download; the copy/close error is
		// what the caller needs to see.
		_ = os.Remove(out.Name())
		return "", err
	}
	return out.Name(), nil
}
|
|
|
|
// https://platform.openai.com/docs/api-reference/images/create
|
|
|
|
/*
|
|
*
|
|
|
|
curl http://localhost:8080/v1/images/generations \
|
|
-H "Content-Type: application/json" \
|
|
-d '{
|
|
"prompt": "A cute baby sea otter",
|
|
"n": 1,
|
|
"size": "512x512"
|
|
}'
|
|
|
|
*
|
|
*/
|
|
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
|
return func(c *fiber.Ctx) error {
|
|
m, input, err := readRequest(c, ml, appConfig, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
}
|
|
|
|
if m == "" {
|
|
m = model.StableDiffusionBackend
|
|
}
|
|
log.Debug().Msgf("Loading model: %+v", m)
|
|
|
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
}
|
|
|
|
src := ""
|
|
if input.File != "" {
|
|
|
|
fileData := []byte{}
|
|
// check if input.File is an URL, if so download it and save it
|
|
// to a temporary file
|
|
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
|
out, err := downloadFile(input.File)
|
|
if err != nil {
|
|
return fmt.Errorf("failed downloading file:%w", err)
|
|
}
|
|
defer os.RemoveAll(out)
|
|
|
|
fileData, err = os.ReadFile(out)
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading file:%w", err)
|
|
}
|
|
|
|
} else {
|
|
// base 64 decode the file and write it somewhere
|
|
// that we will cleanup
|
|
fileData, err = base64.StdEncoding.DecodeString(input.File)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Create a temporary file
|
|
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// write the base64 result
|
|
writer := bufio.NewWriter(outputFile)
|
|
_, err = writer.Write(fileData)
|
|
if err != nil {
|
|
outputFile.Close()
|
|
return err
|
|
}
|
|
outputFile.Close()
|
|
src = outputFile.Name()
|
|
defer os.RemoveAll(src)
|
|
}
|
|
|
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
|
|
switch config.Backend {
|
|
case "stablediffusion":
|
|
config.Backend = model.StableDiffusionBackend
|
|
case "tinydream":
|
|
config.Backend = model.TinyDreamBackend
|
|
case "":
|
|
config.Backend = model.StableDiffusionBackend
|
|
}
|
|
|
|
sizeParts := strings.Split(input.Size, "x")
|
|
if len(sizeParts) != 2 {
|
|
return fmt.Errorf("invalid value for 'size'")
|
|
}
|
|
width, err := strconv.Atoi(sizeParts[0])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid value for 'size'")
|
|
}
|
|
height, err := strconv.Atoi(sizeParts[1])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid value for 'size'")
|
|
}
|
|
|
|
b64JSON := false
|
|
if input.ResponseFormat.Type == "b64_json" {
|
|
b64JSON = true
|
|
}
|
|
// src and clip_skip
|
|
var result []schema.Item
|
|
for _, i := range config.PromptStrings {
|
|
n := input.N
|
|
if input.N == 0 {
|
|
n = 1
|
|
}
|
|
for j := 0; j < n; j++ {
|
|
prompts := strings.Split(i, "|")
|
|
positive_prompt := prompts[0]
|
|
negative_prompt := ""
|
|
if len(prompts) > 1 {
|
|
negative_prompt = prompts[1]
|
|
}
|
|
|
|
mode := 0
|
|
step := config.Step
|
|
if step == 0 {
|
|
step = 15
|
|
}
|
|
|
|
if input.Mode != 0 {
|
|
mode = input.Mode
|
|
}
|
|
|
|
if input.Step != 0 {
|
|
step = input.Step
|
|
}
|
|
|
|
tempDir := ""
|
|
if !b64JSON {
|
|
tempDir = appConfig.ImageDir
|
|
}
|
|
// Create a temporary file
|
|
outputFile, err := os.CreateTemp(tempDir, "b64")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
outputFile.Close()
|
|
output := outputFile.Name() + ".png"
|
|
// Rename the temporary file
|
|
err = os.Rename(outputFile.Name(), output)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
baseURL := c.BaseURL()
|
|
|
|
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := fn(); err != nil {
|
|
return err
|
|
}
|
|
|
|
item := &schema.Item{}
|
|
|
|
if b64JSON {
|
|
defer os.RemoveAll(output)
|
|
data, err := os.ReadFile(output)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
|
} else {
|
|
base := filepath.Base(output)
|
|
item.URL = baseURL + "/generated-images/" + base
|
|
}
|
|
|
|
result = append(result, *item)
|
|
}
|
|
}
|
|
|
|
id := uuid.New().String()
|
|
created := int(time.Now().Unix())
|
|
resp := &schema.OpenAIResponse{
|
|
ID: id,
|
|
Created: created,
|
|
Data: result,
|
|
}
|
|
|
|
jsonResult, _ := json.Marshal(resp)
|
|
log.Debug().Msgf("Response: %s", jsonResult)
|
|
|
|
// Return the prediction in the response body
|
|
return c.JSON(resp)
|
|
}
|
|
}
|