This commit is contained in:
Dave 2024-05-08 00:48:40 +00:00 committed by GitHub
commit 60d2f24fcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 912 additions and 677 deletions

View File

@ -31,7 +31,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
if model == nil {
return fmt.Errorf("could not load model")
return fmt.Errorf("rwkv could not load model")
}
llm.rwkv = model
return nil

View File

@ -1,6 +1,7 @@
package core
import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
@ -17,13 +18,15 @@ type Application struct {
// Core Low-Level Services
BackendConfigLoader *config.BackendConfigLoader
ModelLoader *model.ModelLoader
StoresLoader *model.ModelLoader
// Backend Services
// EmbeddingsBackendService *backend.EmbeddingsBackendService
EmbeddingsBackendService *backend.EmbeddingsBackendService
// ImageGenerationBackendService *backend.ImageGenerationBackendService
// LLMBackendService *backend.LLMBackendService
// TranscriptionBackendService *backend.TranscriptionBackendService
// TextToSpeechBackendService *backend.TextToSpeechBackendService
TextToSpeechBackendService *backend.TextToSpeechBackendService
// RerankBackendService *backend.RerankBackendService
// LocalAI System Services
BackendMonitorService *services.BackendMonitorService
@ -31,6 +34,7 @@ type Application struct {
ListModelsService *services.ListModelsService
LocalAIMetricsService *services.LocalAIMetricsService
// OpenAIService *services.OpenAIService
}
// TODO [NEXT PR?]: Break up ApplicationConfig.

View File

@ -2,14 +2,108 @@ package backend
import (
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
)
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
type EmbeddingsBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}
func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService {
return &EmbeddingsBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}
func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) *concurrency.JobResult[*schema.OpenAIRequest, *schema.OpenAIResponse] {
jr, wjr := concurrency.NewJobResult[*schema.OpenAIRequest, *schema.OpenAIResponse](request)
go func(wjr *concurrency.WritableJobResult[*schema.OpenAIRequest, *schema.OpenAIResponse]) {
id := uuid.New().String()
created := int(time.Now().Unix())
request = *wjr.Request // TODO is needed?
bc, err := ebs.bcl.LoadBackendConfigFileByName(request.Model, ebs.appConfig.ModelPath,
config.LoadOptionDebug(ebs.appConfig.Debug),
config.LoadOptionThreads(ebs.appConfig.Threads),
config.LoadOptionContextSize(ebs.appConfig.ContextSize),
config.LoadOptionF16(ebs.appConfig.F16),
)
if err != nil {
log.Error().Err(err).Str("modelPath", ebs.appConfig.ModelPath).Msg("unable to load backend config")
wjr.SetResult(nil, err)
return
}
// Set the parameters for the language model prediction
bc.UpdateFromOpenAIRequest(request)
items := []schema.Item{}
for i, s := range bc.InputToken {
// get the model function to call for the result
embedFn, err := ebs.modelEmbedding("", s, *bc)
if err != nil {
log.Error().Err(err).Ints("numeric tokens", s).Msg("error during modelEmbedding")
wjr.SetResult(nil, err)
return
}
embeddings, err := embedFn()
if err != nil {
log.Error().Err(err).Ints("numeric tokens", s).Msg("error during embedFn")
wjr.SetResult(nil, err)
return
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
for i, s := range bc.InputStrings {
// get the model function to call for the result
embedFn, err := ebs.modelEmbedding(s, []int{}, *bc)
if err != nil {
log.Error().Err(err).Str("string tokens", s).Msg("error during modelEmbedding")
wjr.SetResult(nil, err)
return
}
embeddings, err := embedFn()
if err != nil {
log.Error().Err(err).Str("string tokens", s).Msg("error during embedFn")
wjr.SetResult(nil, err)
return
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: request.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
wjr.SetResult(resp, nil)
}(wjr)
return jr
}
func (ebs *EmbeddingsBackendService) modelEmbedding(s string, tokens []int, backendConfig config.BackendConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model
grpcOpts := gRPCModelOpts(backendConfig)
@ -17,19 +111,19 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
var inferenceModel interface{}
var err error
opts := modelOpts(backendConfig, appConfig, []model.Option{
opts := modelOpts(backendConfig, ebs.appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(*backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithAssetDir(ebs.appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithContext(ebs.appConfig.Context),
})
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
inferenceModel, err = ebs.ml.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
inferenceModel, err = ebs.ml.BackendLoader(opts...)
}
if err != nil {
return nil, err
@ -39,7 +133,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
switch model := inferenceModel.(type) {
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
predictOptions := gRPCPredictOpts(backendConfig, ebs.appConfig.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}
@ -48,7 +142,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
}
predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(appConfig.Context, predictOptions)
res, err := model.Embeddings(ebs.appConfig.Context, predictOptions)
if err != nil {
return nil, err
}
@ -57,7 +151,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
}
predictOptions.Embeddings = s
res, err := model.Embeddings(appConfig.Context, predictOptions)
res, err := model.Embeddings(ebs.appConfig.Context, predictOptions)
if err != nil {
return nil, err
}

View File

@ -7,12 +7,123 @@ import (
"path/filepath"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/rs/zerolog/log"
"github.com/go-skynet/LocalAI/pkg/concurrency"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
type TextToSpeechBackendService struct {
ml *model.ModelLoader
bcl *config.BackendConfigLoader
appConfig *config.ApplicationConfig
}
func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService {
return &TextToSpeechBackendService{
ml: ml,
bcl: bcl,
appConfig: appConfig,
}
}
func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) *concurrency.JobResult[*schema.TTSRequest, string] {
jr, wjr := concurrency.NewJobResult[*schema.TTSRequest, string](request)
go func(wjr *concurrency.WritableJobResult[*schema.TTSRequest, string]) {
bc, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath,
config.LoadOptionDebug(ttsbs.appConfig.Debug),
config.LoadOptionThreads(ttsbs.appConfig.Threads),
config.LoadOptionContextSize(ttsbs.appConfig.ContextSize),
config.LoadOptionF16(ttsbs.appConfig.F16),
)
if err != nil || bc == nil {
log.Error().Err(err).Str("modelName", request.Model).Str("modelPath", ttsbs.appConfig.ModelPath).Msg("unable to load backend config")
wjr.SetResult("", err)
return
}
if request.Backend != "" { // Allow users to specify a backend to use that overrides config.
bc.Backend = request.Backend
}
// TODO consider merging the below function in, but leave it seperated for diff reasons in the first PR
dst, err := ttsbs.modelTTS(request.Backend, request.Input, bc.Model, request.Voice, *bc)
log.Debug().Str("dst", dst).Err(err).Msg("modelTTS result in goroutine")
wjr.SetResult(dst, err)
}(wjr)
return jr
}
func (ttsbs *TextToSpeechBackendService) modelTTS(backend, text, modelFile, voice string, backendConfig config.BackendConfig) (string, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, ttsbs.appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(ttsbs.appConfig.Context),
model.WithAssetDir(ttsbs.appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
ttsModel, err := ttsbs.ml.BackendLoader(opts...)
if err != nil {
return "", err
}
if ttsModel == nil {
return "", fmt.Errorf("could not load piper model")
}
if ttsbs.appConfig.AudioDir == "" {
return "", fmt.Errorf("ApplicationConfig.AudioDir not set, cannot continue")
}
// Shouldn't be needed anymore. Consider removing later
if err := os.MkdirAll(ttsbs.appConfig.AudioDir, 0750); err != nil {
return "", fmt.Errorf("failed` creating audio directory: %s", err)
}
fileName := generateUniqueFileName(ttsbs.appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(ttsbs.appConfig.AudioDir, fileName)
log.Debug().Str("filePath", filePath).Msg("computed output filePath")
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(ttsbs.appConfig.ModelPath, modelFile)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, ttsbs.appConfig.ModelPath); err != nil {
return "", err
}
modelPath = mp
} else {
modelPath = modelFile
}
}
_, err = ttsModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
})
return filePath, err
}
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
@ -28,62 +139,3 @@ func generateUniqueFileName(dir, baseName, ext string) string {
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}
grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
ttsModel, err := loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if ttsModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, modelFile)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = modelFile
}
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
})
return filePath, res, err
}

View File

@ -126,16 +126,16 @@ func (r *RunCMD) Run(ctx *Context) error {
}
if r.PreloadBackendOnly {
_, _, _, err := startup.Startup(opts...)
_, err := startup.Startup(opts...)
return err
}
cl, ml, options, err := startup.Startup(opts...)
app, err := startup.Startup(opts...)
if err != nil {
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
appHTTP, err := http.App(cl, ml, options)
appHTTP, err := http.App(app)
if err != nil {
log.Error().Err(err).Msg("error during HTTP App construction")
return err

View File

@ -9,6 +9,7 @@ import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@ -48,20 +49,32 @@ func (t *TTSCMD) Run(ctx *Context) error {
}
}()
options := config.BackendConfig{}
options.SetDefaults()
request := &schema.TTSRequest{
Backend: t.Backend,
Input: text,
Model: t.Model,
Voice: t.Voice,
}
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options)
ttsbs := backend.NewTextToSpeechBackendService(ml, config.NewBackendConfigLoader(), opts)
jr := ttsbs.TextToAudioFile(request)
filePathPtr, err := jr.Wait()
if err != nil {
return err
}
if filePathPtr == nil {
err := fmt.Errorf("recieved a nil filepath from TextToAudioFile")
log.Error().Err(err).Msg("tts cli error")
return err
}
if outputFile != "" {
if err := os.Rename(filePath, outputFile); err != nil {
if err := os.Rename(*filePathPtr, outputFile); err != nil {
return err
}
fmt.Printf("Generate file %s\n", outputFile)
} else {
fmt.Printf("Generate file %s\n", filePath)
fmt.Printf("Generate file %s\n", *filePathPtr)
}
return nil
}

View File

@ -1,12 +1,15 @@
package config
import (
"encoding/json"
"fmt"
"os"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/functions"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
const (
@ -332,3 +335,203 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
cfg.Debug = &trueV
}
}
func (config *BackendConfig) UpdateFromOpenAIRequest(input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
log.Error().Err(err).Msg("Failed encoding image")
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}

View File

@ -6,16 +6,16 @@ import (
"net/http"
"strings"
"github.com/go-skynet/LocalAI/core"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/http/routes"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
@ -24,7 +24,6 @@ import (
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
// swagger handler
"github.com/rs/zerolog/log"
)
@ -64,11 +63,11 @@ var embedDirStatic embed.FS
// @in header
// @name Authorization
func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
func App(application *core.Application) (*fiber.App, error) {
// Return errors as JSON responses
app := fiber.New(fiber.Config{
Views: renderEngine(),
BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
BodyLimit: application.ApplicationConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
@ -109,7 +108,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// Default middleware config
if !appConfig.Debug {
if !application.ApplicationConfig.Debug {
app.Use(recover.New())
}
@ -127,11 +126,11 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(appConfig.ApiKeys) == 0 {
if len(application.ApplicationConfig.ApiKeys) == 0 {
return c.Next()
}
if len(appConfig.ApiKeys) == 0 {
if len(application.ApplicationConfig.ApiKeys) == 0 {
return c.Next()
}
@ -147,7 +146,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
}
apiKey := authHeaderParts[1]
for _, key := range appConfig.ApiKeys {
for _, key := range application.ApplicationConfig.ApiKeys {
if apiKey == key {
return c.Next()
}
@ -156,32 +155,36 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
if appConfig.CORS {
if application.ApplicationConfig.CORS {
var c func(ctx *fiber.Ctx) error
if appConfig.CORSAllowOrigins == "" {
if application.ApplicationConfig.CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig.CORSAllowOrigins})
}
app.Use(c)
}
// Load config jsons
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
utils.LoadConfig(application.ApplicationConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(application.ApplicationConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
utils.LoadConfig(application.ApplicationConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
galleryService := services.NewGalleryService(appConfig.ModelPath)
galleryService.Start(appConfig.Context, cl)
// Create the Fiber Content Extractor that the endpoints will use instead of the full modelLoader
fce := ctx.NewFiberContentExtractor(application.ModelLoader, application.ApplicationConfig)
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
if !appConfig.DisableWebUI {
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
// Register all routes - TODO: enhance for partial registration?
// For the "large" register function, it seems to make sense to pass application directly and allow them to sort out their dependencies.
// However, for particularly simple routes, passing dependencies directly may be more clean? Try both and experiment!
routes.RegisterElevenLabsRoutes(app, application.TextToSpeechBackendService, fce, auth)
routes.RegisterLocalAIRoutes(app, application, fce, auth)
routes.RegisterOpenAIRoutes(app, application, fce, auth)
routes.RegisterJINARoutes(app, application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig, auth)
if !application.ApplicationConfig.DisableWebUI {
routes.RegisterUIRoutes(app, application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig, application.GalleryService, auth)
}
routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
httpFS := http.FS(embedDirStatic)

View File

@ -20,7 +20,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -205,9 +205,6 @@ var _ = Describe("API test", func() {
var cancel context.CancelFunc
var tmpdir string
var modelDir string
var bcl *config.BackendConfigLoader
var ml *model.ModelLoader
var applicationConfig *config.ApplicationConfig
commonOpts := []config.AppOption{
config.WithDebug(true),
@ -251,7 +248,7 @@ var _ = Describe("API test", func() {
},
}
bcl, ml, applicationConfig, err = startup.Startup(
application, err := startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithGalleries(galleries),
@ -260,7 +257,7 @@ var _ = Describe("API test", func() {
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -607,7 +604,7 @@ var _ = Describe("API test", func() {
},
}
bcl, ml, applicationConfig, err = startup.Startup(
application, err := startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithAudioDir(tmpdir),
@ -618,7 +615,7 @@ var _ = Describe("API test", func() {
config.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -738,14 +735,14 @@ var _ = Describe("API test", func() {
var err error
bcl, ml, applicationConfig, err = startup.Startup(
application, err := startup.Startup(
append(commonOpts,
config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
config.WithContext(c),
config.WithModelPath(modelPath),
)...)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
@ -1024,14 +1021,14 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background())
var err error
bcl, ml, applicationConfig, err = startup.Startup(
application, err := startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithModelPath(modelPath),
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
app, err = App(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")

View File

@ -0,0 +1,80 @@
package ctx
// This needs to be in a distinct package to avoid cycles between http, routes, and endpoints!
import (
"context"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// This type largely exists to drop permissions from ModelLoader -
// various endpoint functions do not need access to the real "ModelLoader" api, and this type makes that clear.
type FiberContentExtractor struct {
ml *model.ModelLoader
appConfig *config.ApplicationConfig
}
func NewFiberContentExtractor(ml *model.ModelLoader, appConfig *config.ApplicationConfig) *FiberContentExtractor {
return &FiberContentExtractor{
ml: ml,
appConfig: appConfig,
}
}
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func (fce *FiberContentExtractor) ModelFromContext(ctx *fiber.Ctx, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && fce.ml.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := fce.ml.ListModels()
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelInput = bearer
}
return modelInput, nil
}
func (fce *FiberContentExtractor) OpenAIRequestFromContext(ctx *fiber.Ctx, firstModel bool) (*schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// Get input data from the request body
if err := ctx.BodyParser(input); err != nil {
return nil, fmt.Errorf("failed parsing request body: %w", err)
}
context, cancel := context.WithCancel(fce.appConfig.Context)
input.Context = context
input.Cancel = cancel
modelName, err := fce.ModelFromContext(ctx, input.Model, firstModel)
input.Model = modelName
return input, err
}

View File

@ -1,43 +0,0 @@
package fiberContext
import (
"fmt"
"strings"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
if ctx.Params("model") != "" {
modelInput = ctx.Params("model")
}
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelInput == "" && !bearerExists && firstModel {
models, _ := loader.ListModels()
if len(models) > 0 {
modelInput = models[0]
log.Debug().Msgf("No model specified, using: %s", modelInput)
} else {
log.Debug().Msgf("No model specified, returning error")
return "", fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelInput = bearer
}
return modelInput, nil
}

View File

@ -1,10 +1,10 @@
package elevenlabs
import (
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
@ -17,7 +17,7 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TTSEndpoint(ttsbs *backend.TextToSpeechBackendService, fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.ElevenLabsTTSRequest)
@ -28,34 +28,30 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
modelFile, err := fce.ModelFromContext(c, input.ModelID, false)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.ModelID
log.Warn().Msgf("Model not found in context: %s", input.ModelID)
} else {
if input.ModelID != "" {
modelFile = input.ModelID
} else {
modelFile = cfg.Model
}
}
log.Debug().Msgf("Request for model: %s", modelFile)
log.Debug().Str("modelName", modelFile).Msg("elevenlabs TTS request recieved for model")
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
ttsRequest := &schema.TTSRequest{
Model: input.ModelID,
Input: input.Text,
Voice: voiceID,
}
jr := ttsbs.TextToAudioFile(ttsRequest)
filePathPtr, err := jr.Wait()
if err != nil {
return err
}
return c.Download(filePath)
if filePathPtr == nil {
err := fmt.Errorf("recieved a nil filepath from TextToAudioFile")
log.Error().Err(err).Msg("eleventlabs TTSEndpoint error")
return err
}
return c.Download(*filePathPtr)
}
}

View File

@ -4,7 +4,7 @@ import (
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
@ -28,7 +28,9 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
tempFCE := ctx.NewFiberContentExtractor(ml, appConfig)
modelFile, err := tempFCE.ModelFromContext(c, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)

View File

@ -1,10 +1,10 @@
package localai
import (
"fmt"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
@ -16,45 +16,41 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/audio/speech [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func TTSEndpoint(ttsbs *backend.TextToSpeechBackendService, fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TTSRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
log.Error().Err(err).Msg("Error during BodyParser")
return err
}
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
modelFile, err := fce.ModelFromContext(c, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
log.Warn().Str("input.Model", input.Model).Msg("Model not found in context, using input.Model")
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)
if input.Backend != "" {
cfg.Backend = input.Backend
log.Debug().Str("initial input.Model", input.Model).Str("modelFile", modelFile).Msg("overwriting input.Model with modelFile")
input.Model = modelFile
}
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
log.Debug().Str("modelName", modelFile).Msg("localai TTS request recieved for model")
jr := ttsbs.TextToAudioFile(input)
log.Debug().Msg("Obtained JobResult, waiting")
filePathPtr, err := jr.Wait()
if err != nil {
log.Error().Err(err).Msg("Error during TextToAudioFile")
return err
}
return c.Download(filePath)
if filePathPtr == nil {
err := fmt.Errorf("recieved a nil filepath from TextToAudioFile")
log.Error().Err(err).Msg("localai TTSEndpoint error")
return err
}
log.Debug().Str("filePath", *filePathPtr).Msg("Successfully created output audio file at filePath")
return c.Download(*filePathPtr)
}
}

View File

@ -1,16 +1,10 @@
package openai
import (
"encoding/json"
"fmt"
"time"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
@ -21,63 +15,27 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func EmbeddingsEndpoint(ebs *backend.EmbeddingsBackendService, fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, ml, appConfig, true)
request, err := fce.OpenAIRequestFromContext(c, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
return fmt.Errorf("failed reading parameters from request: %w", err)
}
config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
jr := ebs.Embeddings(request)
resp, err := jr.Wait()
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
log.Error().Err(err).Msg("error during embedding")
return err
}
log.Debug().Msgf("Parameter Config: %+v", config)
items := []schema.Item{}
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
if resp == nil {
err := fmt.Errorf("recieved a nil response from embeddings backend")
log.Error().Err(err).Msg("EmbeddingsEndpoint nil result")
return err
}
for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
if err != nil {
return err
}
embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
id := uuid.New().String()
created := int(time.Now().Unix())
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: items,
Object: "list",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(resp)
return c.JSON(*resp)
}
}

View File

@ -2,18 +2,13 @@ package openai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/go-skynet/LocalAI/core/config"
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/functions"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@ -28,253 +23,22 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
context, cancel := context.WithCancel(o.Context)
input.Context = context
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
// TEMPORARY STUB DURING DEVELOPMENT
fce := ctx.NewFiberContentExtractor(ml, o)
modelFile, err := fce.ModelFromContext(c, input.Model, firstModel)
return modelFile, input, err
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/...;base64,", drop it
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
for _, prefix := range dropPrefix {
if strings.HasPrefix(s, prefix) {
return strings.ReplaceAll(s, prefix, ""), nil
}
}
return "", fmt.Errorf("not valid string")
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
log.Error().Msgf("Failed encoding image: %s", err)
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
config.LoadOptionDebug(debug),
config.LoadOptionThreads(threads),
@ -283,7 +47,7 @@ func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *c
)
// Set the parameters for the language model prediction
updateRequestConfig(cfg, input)
cfg.UpdateFromOpenAIRequest(input)
return cfg, input, err
}

View File

@ -1,19 +1,18 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterElevenLabsRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
ttsbs *backend.TextToSpeechBackendService,
fce *ctx.FiberContentExtractor,
auth func(*fiber.Ctx) error) {
// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(ttsbs, fce))
}

View File

@ -3,8 +3,8 @@ package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/jina"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)

View File

@ -1,27 +1,24 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/swagger"
)
func RegisterLocalAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
galleryService *services.GalleryService,
application *core.Application,
fce *ctx.FiberContentExtractor,
auth func(*fiber.Ctx) error) {
app.Get("/swagger/*", swagger.HandlerDefault) // default
// LocalAI API endpoints
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(application.ApplicationConfig.Galleries, application.ApplicationConfig.ModelPath, application.GalleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
@ -32,14 +29,13 @@ func RegisterLocalAIRoutes(app *fiber.App,
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/tts", auth, localai.TTSEndpoint(application.TextToSpeechBackendService, fce))
// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
// Stores : TODO IS THIS REALLY A SERVICE? OR IS IT PURELY WEB API FEATURE?
app.Post("/stores/set", auth, localai.StoresSetEndpoint(application.StoresLoader, application.ApplicationConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(application.StoresLoader, application.ApplicationConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(application.StoresLoader, application.ApplicationConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(application.StoresLoader, application.ApplicationConfig))
// Kubernetes health checks
ok := func(c *fiber.Ctx) error {
@ -51,10 +47,8 @@ func RegisterLocalAIRoutes(app *fiber.App,
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
// Experimental Backend Statistics Module
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(application.BackendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(application.BackendMonitorService))
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {

View File

@ -1,88 +1,85 @@
package routes
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core"
"github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)
func RegisterOpenAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
application *core.Application,
fce *ctx.FiberContentExtractor,
auth func(*fiber.Ctx) error) {
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/v1/edits", auth, openai.EditEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/edits", auth, openai.EditEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
// assistant
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/completions", auth, openai.CompletionEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(application.EmbeddingsBackendService, fce))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(application.EmbeddingsBackendService, fce))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(application.EmbeddingsBackendService, fce))
// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(application.TextToSpeechBackendService, fce))
// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
if application.ApplicationConfig.ImageDir != "" {
app.Static("/generated-images", application.ApplicationConfig.ImageDir)
}
if appConfig.AudioDir != "" {
app.Static("/generated-audio", appConfig.AudioDir)
if application.ApplicationConfig.AudioDir != "" {
app.Static("/generated-audio", application.ApplicationConfig.AudioDir)
}
// models
tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
app.Get("/v1/models", auth, openai.ListModelsEndpoint(application.ListModelsService))
app.Get("/models", auth, openai.ListModelsEndpoint(application.ListModelsService))
}

View File

@ -21,6 +21,11 @@ type TTSRequest struct {
Backend string `json:"backend" yaml:"backend"`
}
type RerankRequest struct {
JINARerankRequest
Backend string `json:"backend" yaml:"backend"`
}
type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

View File

@ -5,6 +5,7 @@ import (
"os"
"github.com/go-skynet/LocalAI/core"
"github.com/go-skynet/LocalAI/core/backend"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
@ -15,10 +16,10 @@ import (
"github.com/rs/zerolog/log"
)
func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
options := config.NewApplicationConfig(opts...)
func Startup(opts ...config.AppOption) (*core.Application, error) {
appConfig := config.NewApplicationConfig(opts...)
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", appConfig.Threads, appConfig.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
caps, err := xsysinfo.CPUCapabilities()
if err == nil {
@ -33,77 +34,78 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
}
// Make sure directories exists
if options.ModelPath == "" {
return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
if appConfig.ModelPath == "" {
return nil, fmt.Errorf("options.ModelPath cannot be empty")
}
err = os.MkdirAll(options.ModelPath, 0750)
err = os.MkdirAll(appConfig.ModelPath, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
}
if options.ImageDir != "" {
err := os.MkdirAll(options.ImageDir, 0750)
if appConfig.ImageDir != "" {
err := os.MkdirAll(appConfig.ImageDir, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
return nil, fmt.Errorf("unable to create ImageDir: %q", err)
}
}
if options.AudioDir != "" {
err := os.MkdirAll(options.AudioDir, 0750)
if appConfig.AudioDir != "" {
err := os.MkdirAll(appConfig.AudioDir, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
return nil, fmt.Errorf("unable to create AudioDir: %q", err)
}
}
if options.UploadDir != "" {
err := os.MkdirAll(options.UploadDir, 0750)
if appConfig.UploadDir != "" {
err := os.MkdirAll(appConfig.UploadDir, 0750)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
return nil, fmt.Errorf("unable to create UploadDir: %q", err)
}
}
//
pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...)
// TODO DAVE INSPECT HERE
pkgStartup.PreloadModelsConfigurations(appConfig.ModelLibraryURL, appConfig.ModelPath, appConfig.ModelsURL...)
cl := config.NewBackendConfigLoader()
ml := model.NewModelLoader(options.ModelPath)
app := createApplication(appConfig)
configLoaderOpts := options.ToConfigLoaderOptions()
configLoaderOpts := appConfig.ToConfigLoaderOptions()
if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
if err := app.BackendConfigLoader.LoadBackendConfigsFromPath(appConfig.ModelPath, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config files")
}
if options.ConfigFile != "" {
if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil {
if appConfig.ConfigFile != "" {
if err := app.BackendConfigLoader.LoadBackendConfigFile(appConfig.ConfigFile, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config file")
}
}
if err := cl.Preload(options.ModelPath); err != nil {
if err := app.BackendConfigLoader.Preload(appConfig.ModelPath); err != nil {
log.Error().Err(err).Msg("error downloading models")
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, nil, err
if appConfig.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(appConfig.ModelPath, appConfig.PreloadJSONModels, app.BackendConfigLoader, appConfig.Galleries); err != nil {
return nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, nil, err
if appConfig.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(appConfig.ModelPath, appConfig.PreloadModelsFromPath, app.BackendConfigLoader, appConfig.Galleries); err != nil {
return nil, err
}
}
if options.Debug {
for _, v := range cl.ListBackendConfigs() {
cfg, _ := cl.GetBackendConfig(v)
if appConfig.Debug {
for _, v := range app.BackendConfigLoader.ListBackendConfigs() {
cfg, _ := app.BackendConfigLoader.GetBackendConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
if appConfig.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
err := assets.ExtractFiles(appConfig.BackendAssets, appConfig.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", appConfig.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
@ -111,25 +113,25 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
<-appConfig.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
err := ml.StopAllGRPC()
err := app.ModelLoader.StopAllGRPC()
if err != nil {
log.Error().Err(err).Msg("error while stopping all grpc backends")
}
}()
if options.WatchDog {
if appConfig.WatchDog {
wd := model.NewWatchDog(
ml,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
ml.SetWatchDog(wd)
app.ModelLoader,
appConfig.WatchDogBusyTimeout,
appConfig.WatchDogIdleTimeout,
appConfig.WatchDogBusy,
appConfig.WatchDogIdle)
app.ModelLoader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
<-appConfig.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
@ -137,14 +139,14 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
// Watch the configuration directory
// If the directory does not exist, we don't watch it
configHandler := newConfigFileHandler(options)
configHandler := newConfigFileHandler(appConfig)
err = configHandler.Watch()
if err != nil {
log.Error().Err(err).Msg("error establishing configuration directory watcher")
}
log.Info().Msg("core/startup process completed!")
return cl, ml, options, nil
return app, nil
}
// In Lieu of a proper DI framework, this function wires up the Application manually.
@ -154,18 +156,23 @@ func createApplication(appConfig *config.ApplicationConfig) *core.Application {
ApplicationConfig: appConfig,
BackendConfigLoader: config.NewBackendConfigLoader(),
ModelLoader: model.NewModelLoader(appConfig.ModelPath),
StoresLoader: model.NewModelLoader(""),
}
var err error
// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.RerankBackendService = backend.NewRerankBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath)
app.GalleryService.Start(app.ApplicationConfig.Context, app.BackendConfigLoader)
app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)

View File

@ -3,4 +3,5 @@ vars {
PORT: 8080
DEFAULT_MODEL: gpt-3.5-turbo
PROTOCOL: http://
DEFAULT_TTS_MODEL: voice-en-us-kathleen-low
}

View File

@ -16,7 +16,7 @@ headers {
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"model": "bert-embeddings",
"input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?"
}
}

View File

@ -16,7 +16,8 @@ headers {
body:json {
{
"model": "{{DEFAULT_MODEL}}",
"input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?"
"model": "{{DEFAULT_TTS_MODEL}}",
"input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?",
"backend": "piper"
}
}

10
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf
github.com/Masterminds/sprig/v3 v3.2.3
github.com/charmbracelet/glamour v0.7.0
github.com/chasefleming/elem-go v0.25.0
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
github.com/fsnotify/fsnotify v1.7.0
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
@ -18,10 +19,12 @@ require (
github.com/gofiber/swagger v1.0.0
github.com/gofiber/template/html/v2 v2.1.1
github.com/google/uuid v1.5.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hpcloud/tail v1.0.0
github.com/imdario/mergo v0.3.16
github.com/jaypipes/ghw v0.12.0
github.com/klauspost/cpuid/v2 v2.2.7
github.com/mholt/archiver/v3 v3.5.1
github.com/microcosm-cc/bluemonday v1.0.26
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231022042237-c25dc5193530
@ -74,7 +77,6 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chasefleming/elem-go v0.25.0 // indirect
github.com/containerd/continuity v0.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
@ -97,15 +99,12 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/huandu/xstrings v1.3.3 // indirect
github.com/jaypipes/ghw v0.12.0 // indirect
github.com/jaypipes/pcidb v1.0.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/klauspost/pgzip v1.2.5 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/microcosm-cc/bluemonday v1.0.26 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
github.com/mitchellh/copystructure v1.0.0 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
@ -158,7 +157,6 @@ require (
github.com/gofiber/contrib/fiberzerolog v1.0.0
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect

16
go.sum
View File

@ -149,14 +149,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
@ -299,8 +293,6 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
@ -391,8 +383,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
@ -411,8 +401,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -450,16 +438,12 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -0,0 +1,64 @@
package concurrency
import (
"sync"
)
// This is a Read-ONLY structure that contains the result of an arbitrary asynchronous action
type JobResult[RequestType any, ResultType any] struct {
Request *RequestType
Result *ResultType
Error error
mu *sync.Mutex
done *chan struct{}
}
// This structure is returned in a pair with a JobResult and serves as the structure that has access to be updated.
type WritableJobResult[RequestType any, ResultType any] struct {
*JobResult[RequestType, ResultType]
}
// Wait blocks until the result is ready and then returns the result.
// Returns *ResultType instead of ResultType since its possible we have only an error and nil for ResultType.
// Is this correct and idiomatic?
func (jr *JobResult[RequestType, ResultType]) Wait() (*ResultType, error) {
if jr.done == nil { // If the channel is blanked out, result is ready.
return jr.Result, jr.Error
}
<-*jr.done // Wait for the result to be ready
jr.mu.Lock()
defer func() {
jr.done = nil
jr.mu.Unlock()
}()
if jr.Error != nil {
return nil, jr.Error
}
return jr.Result, nil
}
// This is the function that actually updates the Result and Error on the JobResult... but it's normally not accessible
func (jr *JobResult[RequestType, ResultType]) setResult(result ResultType, err error) {
jr.mu.Lock()
defer jr.mu.Unlock()
jr.Result = &result
jr.Error = err
close(*jr.done) // Signal that the result is ready
}
// Only the WritableJobResult can actually call setResult - prevents accidental corruption
func (wjr *WritableJobResult[RequestType, ResultType]) SetResult(result ResultType, err error) {
wjr.JobResult.setResult(result, err)
}
// NewJobResult binds a request to a matched pair of JobResult and WritableJobResult
func NewJobResult[RequestType any, ResultType any](request RequestType) (*JobResult[RequestType, ResultType], *WritableJobResult[RequestType, ResultType]) {
mu := &sync.Mutex{}
done := make(chan struct{})
jr := &JobResult[RequestType, ResultType]{
mu: mu,
Request: &request,
done: &done,
}
return jr, &WritableJobResult[RequestType, ResultType]{JobResult: jr}
}

View File

@ -0,0 +1,53 @@
package concurrency_test
import (
"fmt"
"time"
. "github.com/go-skynet/LocalAI/pkg/concurrency"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("pkg/concurrency unit tests", func() {
It("can be used to recieve a result across goroutines", func() {
jr, wjr := NewJobResult[string, string]("foo")
Expect(jr).ToNot(BeNil())
Expect(wjr).ToNot(BeNil())
stallChannel := make(chan struct{})
go func(jr *JobResult[string, string]) {
resPtr, err := jr.Wait()
Expect(err).To(BeNil())
Expect(jr.Request).ToNot(BeNil())
Expect(*jr.Request).To(Equal("foo"))
Expect(resPtr).ToNot(BeNil())
Expect(*resPtr).To(Equal("bar"))
close(stallChannel)
}(jr)
go func(wjr *WritableJobResult[string, string]) {
time.Sleep(time.Second * 5)
wjr.SetResult("bar", nil)
}(wjr)
<-stallChannel
})
It("can be used to recieve an error across goroutines", func() {
jr, wjr := NewJobResult[string, string]("foo")
Expect(jr).ToNot(BeNil())
Expect(wjr).ToNot(BeNil())
stallChannel := make(chan struct{})
go func(jr *JobResult[string, string]) {
_, err := jr.Wait()
Expect(jr.Request).To(BeNil())
Expect(*jr.Request).To(Equal("foo"))
Expect(err).ToNot(BeNil())
Expect(err).To(MatchError("test"))
close(stallChannel)
}(jr)
go func(wjr *WritableJobResult[string, string]) {
time.Sleep(time.Second * 5)
wjr.SetResult("", fmt.Errorf("test"))
}(wjr)
<-stallChannel
})
})

View File

@ -207,10 +207,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
if err != nil {
return "", fmt.Errorf("could not load model: %w", err)
return "", fmt.Errorf("backend %q could not load model: %w", backend, err)
}
if !res.Success {
return "", fmt.Errorf("could not load model (no success): %s", res.Message)
return "", fmt.Errorf("backend %q could not load model (no success): %s", backend, res.Message)
}
return client, nil

View File

@ -42,9 +42,13 @@ func GetImageURLAsBase64(s string) (string, error) {
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
// if the string instead is prefixed with "data:image/...;base64,", drop it
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
for _, prefix := range dropPrefix {
if strings.HasPrefix(s, prefix) {
return strings.ReplaceAll(s, prefix, ""), nil
}
}
return "", fmt.Errorf("not valid string")
}
}

View File

@ -7,13 +7,20 @@ import (
)
var _ = Describe("utils/base64 tests", func() {
It("GetImageURLAsBase64 can strip data url prefixes", func() {
It("GetImageURLAsBase64 can strip jpeg data url prefixes", func() {
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
input := "data:image/jpeg;base64,FOO"
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).To(Equal("FOO"))
})
It("GetImageURLAsBase64 can strip png data url prefixes", func() {
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
input := "data:image/png;base64,BAR"
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).To(Equal("BAR"))
})
It("GetImageURLAsBase64 returns an error for bogus data", func() {
input := "FOO"
b64, err := GetImageURLAsBase64(input)

View File

@ -40,7 +40,9 @@ var _ = BeforeSuite(func() {
if apiEndpoint == "" {
startDockerImage()
defaultConfig = openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://localhost:" + apiPort + "/v1"
apiEndpoint = "http://localhost:" + apiPort + "/v1" // So that other tests can reference this value safely.
defaultConfig.BaseURL = apiEndpoint
} else {
fmt.Println("Default ", apiEndpoint)
defaultConfig = openai.DefaultConfig(apiKey)