mirror of https://github.com/mudler/LocalAI.git
Merge 8419b10c4e into 6559ac11b1
commit 60d2f24fcc
@@ -31,7 +31,7 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
    model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))

    if model == nil {
-       return fmt.Errorf("could not load model")
+       return fmt.Errorf("rwkv could not load model")
    }
    llm.rwkv = model
    return nil
@@ -1,6 +1,7 @@
package core

import (
+   "github.com/go-skynet/LocalAI/core/backend"
    "github.com/go-skynet/LocalAI/core/config"
    "github.com/go-skynet/LocalAI/core/services"
    "github.com/go-skynet/LocalAI/pkg/model"

@@ -17,13 +18,15 @@ type Application struct {
    // Core Low-Level Services
    BackendConfigLoader *config.BackendConfigLoader
    ModelLoader         *model.ModelLoader
    StoresLoader        *model.ModelLoader

    // Backend Services
    // EmbeddingsBackendService *backend.EmbeddingsBackendService
+   EmbeddingsBackendService *backend.EmbeddingsBackendService
    // ImageGenerationBackendService *backend.ImageGenerationBackendService
    // LLMBackendService *backend.LLMBackendService
    // TranscriptionBackendService *backend.TranscriptionBackendService
    // TextToSpeechBackendService *backend.TextToSpeechBackendService
+   TextToSpeechBackendService *backend.TextToSpeechBackendService
    // RerankBackendService *backend.RerankBackendService

    // LocalAI System Services
    BackendMonitorService *services.BackendMonitorService

@@ -31,6 +34,7 @@ type Application struct {
    ListModelsService     *services.ListModelsService
    LocalAIMetricsService *services.LocalAIMetricsService
    // OpenAIService *services.OpenAIService
}

// TODO [NEXT PR?]: Break up ApplicationConfig.
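The two uncommented fields above are the only backend services wired up so far; the commented entries sketch where the remaining services will land in later PRs. For orientation, here is a hedged sketch of how the new fields are presumably populated during startup. The helper name newApplication is illustrative only; the real wiring lives in pkg/startup and is not shown in this diff.

package startup

// Hypothetical wiring sketch: it only illustrates how the new backend-service
// fields on core.Application are expected to be filled from the low-level
// loaders. The constructor signatures match the ones added in this commit.
import (
    "github.com/go-skynet/LocalAI/core"
    "github.com/go-skynet/LocalAI/core/backend"
    "github.com/go-skynet/LocalAI/core/config"
    "github.com/go-skynet/LocalAI/pkg/model"
)

func newApplication(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) *core.Application {
    return &core.Application{
        BackendConfigLoader:        bcl,
        ModelLoader:                ml,
        ApplicationConfig:          appConfig,
        EmbeddingsBackendService:   backend.NewEmbeddingsBackendService(ml, bcl, appConfig),
        TextToSpeechBackendService: backend.NewTextToSpeechBackendService(ml, bcl, appConfig),
    }
}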
@@ -2,14 +2,108 @@ package backend

import (
    "fmt"
    "time"

    "github.com/go-skynet/LocalAI/core/config"
    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/google/uuid"
    "github.com/rs/zerolog/log"

    "github.com/go-skynet/LocalAI/pkg/concurrency"
    "github.com/go-skynet/LocalAI/pkg/grpc"
-   model "github.com/go-skynet/LocalAI/pkg/model"
+   "github.com/go-skynet/LocalAI/pkg/model"
)

-func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
+type EmbeddingsBackendService struct {
    ml        *model.ModelLoader
    bcl       *config.BackendConfigLoader
    appConfig *config.ApplicationConfig
}

func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService {
    return &EmbeddingsBackendService{
        ml:        ml,
        bcl:       bcl,
        appConfig: appConfig,
    }
}

func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) *concurrency.JobResult[*schema.OpenAIRequest, *schema.OpenAIResponse] {

    jr, wjr := concurrency.NewJobResult[*schema.OpenAIRequest, *schema.OpenAIResponse](request)

    go func(wjr *concurrency.WritableJobResult[*schema.OpenAIRequest, *schema.OpenAIResponse]) {
        id := uuid.New().String()
        created := int(time.Now().Unix())
        request = *wjr.Request // TODO is needed?

        bc, err := ebs.bcl.LoadBackendConfigFileByName(request.Model, ebs.appConfig.ModelPath,
            config.LoadOptionDebug(ebs.appConfig.Debug),
            config.LoadOptionThreads(ebs.appConfig.Threads),
            config.LoadOptionContextSize(ebs.appConfig.ContextSize),
            config.LoadOptionF16(ebs.appConfig.F16),
        )
        if err != nil {
            log.Error().Err(err).Str("modelPath", ebs.appConfig.ModelPath).Msg("unable to load backend config")
            wjr.SetResult(nil, err)
            return
        }

        // Set the parameters for the language model prediction
        bc.UpdateFromOpenAIRequest(request)

        items := []schema.Item{}

        for i, s := range bc.InputToken {
            // get the model function to call for the result
            embedFn, err := ebs.modelEmbedding("", s, *bc)
            if err != nil {
                log.Error().Err(err).Ints("numeric tokens", s).Msg("error during modelEmbedding")
                wjr.SetResult(nil, err)
                return
            }

            embeddings, err := embedFn()
            if err != nil {
                log.Error().Err(err).Ints("numeric tokens", s).Msg("error during embedFn")
                wjr.SetResult(nil, err)
                return
            }
            items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
        }

        for i, s := range bc.InputStrings {
            // get the model function to call for the result
            embedFn, err := ebs.modelEmbedding(s, []int{}, *bc)
            if err != nil {
                log.Error().Err(err).Str("string tokens", s).Msg("error during modelEmbedding")
                wjr.SetResult(nil, err)
                return
            }

            embeddings, err := embedFn()
            if err != nil {
                log.Error().Err(err).Str("string tokens", s).Msg("error during embedFn")
                wjr.SetResult(nil, err)
                return
            }
            items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
        }

        resp := &schema.OpenAIResponse{
            ID:      id,
            Created: created,
            Model:   request.Model, // we have to return what the user sent here, due to OpenAI spec.
            Data:    items,
            Object:  "list",
        }
        wjr.SetResult(resp, nil)
    }(wjr)

    return jr
}

func (ebs *EmbeddingsBackendService) modelEmbedding(s string, tokens []int, backendConfig config.BackendConfig) (func() ([]float32, error), error) {
    modelFile := backendConfig.Model

    grpcOpts := gRPCModelOpts(backendConfig)

@@ -17,19 +111,19 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
    var inferenceModel interface{}
    var err error

-   opts := modelOpts(backendConfig, appConfig, []model.Option{
+   opts := modelOpts(backendConfig, ebs.appConfig, []model.Option{
        model.WithLoadGRPCLoadModelOpts(grpcOpts),
        model.WithThreads(uint32(*backendConfig.Threads)),
-       model.WithAssetDir(appConfig.AssetsDestination),
+       model.WithAssetDir(ebs.appConfig.AssetsDestination),
        model.WithModel(modelFile),
-       model.WithContext(appConfig.Context),
+       model.WithContext(ebs.appConfig.Context),
    })

    if backendConfig.Backend == "" {
-       inferenceModel, err = loader.GreedyLoader(opts...)
+       inferenceModel, err = ebs.ml.GreedyLoader(opts...)
    } else {
        opts = append(opts, model.WithBackendString(backendConfig.Backend))
-       inferenceModel, err = loader.BackendLoader(opts...)
+       inferenceModel, err = ebs.ml.BackendLoader(opts...)
    }
    if err != nil {
        return nil, err

@@ -39,7 +133,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
    switch model := inferenceModel.(type) {
    case grpc.Backend:
        fn = func() ([]float32, error) {
-           predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
+           predictOptions := gRPCPredictOpts(backendConfig, ebs.appConfig.ModelPath)
            if len(tokens) > 0 {
                embeds := []int32{}

@@ -48,7 +142,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
            }
            predictOptions.EmbeddingTokens = embeds

-           res, err := model.Embeddings(appConfig.Context, predictOptions)
+           res, err := model.Embeddings(ebs.appConfig.Context, predictOptions)
            if err != nil {
                return nil, err
            }

@@ -57,7 +151,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
        }
        predictOptions.Embeddings = s

-       res, err := model.Embeddings(appConfig.Context, predictOptions)
+       res, err := model.Embeddings(ebs.appConfig.Context, predictOptions)
        if err != nil {
            return nil, err
        }
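The Embeddings method returns immediately with a JobResult handle while a goroutine performs the work and publishes through the writable side. The pkg/concurrency implementation itself is not part of this diff; the following minimal sketch is inferred purely from the call sites above (NewJobResult, SetResult, Wait, and the Request field) and may differ from the real package.

package concurrency

import "sync"

// Minimal sketch of the JobResult pair used above, assuming only what the
// call sites in this commit require.
type JobResult[Req any, Resp any] struct {
    Request *Req
    result  *Resp
    err     error
    once    sync.Once
    done    chan struct{}
}

// WritableJobResult is the producer-side handle; only the goroutine that
// performs the work should call SetResult.
type WritableJobResult[Req any, Resp any] struct {
    *JobResult[Req, Resp]
}

func NewJobResult[Req any, Resp any](request Req) (*JobResult[Req, Resp], *WritableJobResult[Req, Resp]) {
    jr := &JobResult[Req, Resp]{Request: &request, done: make(chan struct{})}
    return jr, &WritableJobResult[Req, Resp]{JobResult: jr}
}

// SetResult publishes the outcome exactly once and unblocks all waiters.
func (wjr *WritableJobResult[Req, Resp]) SetResult(result Resp, err error) {
    wjr.once.Do(func() {
        wjr.result = &result
        wjr.err = err
        close(wjr.done)
    })
}

// Wait blocks until SetResult has been called and returns the outcome.
func (jr *JobResult[Req, Resp]) Wait() (*Resp, error) {
    <-jr.done
    return jr.result, jr.err
}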
@@ -7,12 +7,123 @@ import (
    "path/filepath"

    "github.com/go-skynet/LocalAI/core/config"
    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/rs/zerolog/log"

    "github.com/go-skynet/LocalAI/pkg/concurrency"
    "github.com/go-skynet/LocalAI/pkg/grpc/proto"
-   model "github.com/go-skynet/LocalAI/pkg/model"
+   "github.com/go-skynet/LocalAI/pkg/model"
    "github.com/go-skynet/LocalAI/pkg/utils"
)

type TextToSpeechBackendService struct {
    ml        *model.ModelLoader
    bcl       *config.BackendConfigLoader
    appConfig *config.ApplicationConfig
}

func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService {
    return &TextToSpeechBackendService{
        ml:        ml,
        bcl:       bcl,
        appConfig: appConfig,
    }
}

func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) *concurrency.JobResult[*schema.TTSRequest, string] {
    jr, wjr := concurrency.NewJobResult[*schema.TTSRequest, string](request)

    go func(wjr *concurrency.WritableJobResult[*schema.TTSRequest, string]) {
        bc, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath,
            config.LoadOptionDebug(ttsbs.appConfig.Debug),
            config.LoadOptionThreads(ttsbs.appConfig.Threads),
            config.LoadOptionContextSize(ttsbs.appConfig.ContextSize),
            config.LoadOptionF16(ttsbs.appConfig.F16),
        )
        if err != nil || bc == nil {
            log.Error().Err(err).Str("modelName", request.Model).Str("modelPath", ttsbs.appConfig.ModelPath).Msg("unable to load backend config")
            wjr.SetResult("", err)
            return
        }

        if request.Backend != "" { // Allow users to specify a backend to use that overrides config.
            bc.Backend = request.Backend
        }
        // TODO consider merging the below function in, but leave it separated for diff reasons in the first PR
        dst, err := ttsbs.modelTTS(request.Backend, request.Input, bc.Model, request.Voice, *bc)
        log.Debug().Str("dst", dst).Err(err).Msg("modelTTS result in goroutine")
        wjr.SetResult(dst, err)
    }(wjr)

    return jr
}

func (ttsbs *TextToSpeechBackendService) modelTTS(backend, text, modelFile, voice string, backendConfig config.BackendConfig) (string, error) {
    bb := backend
    if bb == "" {
        bb = model.PiperBackend
    }

    grpcOpts := gRPCModelOpts(backendConfig)

    opts := modelOpts(config.BackendConfig{}, ttsbs.appConfig, []model.Option{
        model.WithBackendString(bb),
        model.WithModel(modelFile),
        model.WithContext(ttsbs.appConfig.Context),
        model.WithAssetDir(ttsbs.appConfig.AssetsDestination),
        model.WithLoadGRPCLoadModelOpts(grpcOpts),
    })
    ttsModel, err := ttsbs.ml.BackendLoader(opts...)
    if err != nil {
        return "", err
    }

    if ttsModel == nil {
        return "", fmt.Errorf("could not load piper model")
    }

    if ttsbs.appConfig.AudioDir == "" {
        return "", fmt.Errorf("ApplicationConfig.AudioDir not set, cannot continue")
    }

    // Shouldn't be needed anymore. Consider removing later
    if err := os.MkdirAll(ttsbs.appConfig.AudioDir, 0750); err != nil {
        return "", fmt.Errorf("failed creating audio directory: %s", err)
    }

    fileName := generateUniqueFileName(ttsbs.appConfig.AudioDir, "tts", ".wav")
    filePath := filepath.Join(ttsbs.appConfig.AudioDir, fileName)

    log.Debug().Str("filePath", filePath).Msg("computed output filePath")

    // If the model file is not empty, we pass it joined with the model path
    modelPath := ""
    if modelFile != "" {
        // If the model file is not empty, we pass it joined with the model path
        // Checking first that it exists and is not outside ModelPath
        // TODO: we should actually first check if the modelFile is looking like
        // a FS path
        mp := filepath.Join(ttsbs.appConfig.ModelPath, modelFile)
        if _, err := os.Stat(mp); err == nil {
            if err := utils.VerifyPath(mp, ttsbs.appConfig.ModelPath); err != nil {
                return "", err
            }
            modelPath = mp
        } else {
            modelPath = modelFile
        }
    }

    _, err = ttsModel.TTS(context.Background(), &proto.TTSRequest{
        Text:  text,
        Model: modelPath,
        Voice: voice,
        Dst:   filePath,
    })

    return filePath, err
}

func generateUniqueFileName(dir, baseName, ext string) string {
    counter := 1
    fileName := baseName + ext

@@ -28,62 +139,3 @@ func generateUniqueFileName(dir, baseName, ext string) string {
        fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
    }
}

func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
    bb := backend
    if bb == "" {
        bb = model.PiperBackend
    }

    grpcOpts := gRPCModelOpts(backendConfig)

    opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
        model.WithBackendString(bb),
        model.WithModel(modelFile),
        model.WithContext(appConfig.Context),
        model.WithAssetDir(appConfig.AssetsDestination),
        model.WithLoadGRPCLoadModelOpts(grpcOpts),
    })
    ttsModel, err := loader.BackendLoader(opts...)
    if err != nil {
        return "", nil, err
    }

    if ttsModel == nil {
        return "", nil, fmt.Errorf("could not load piper model")
    }

    if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
        return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
    }

    fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
    filePath := filepath.Join(appConfig.AudioDir, fileName)

    // If the model file is not empty, we pass it joined with the model path
    modelPath := ""
    if modelFile != "" {
        // If the model file is not empty, we pass it joined with the model path
        // Checking first that it exists and is not outside ModelPath
        // TODO: we should actually first check if the modelFile is looking like
        // a FS path
        mp := filepath.Join(loader.ModelPath, modelFile)
        if _, err := os.Stat(mp); err == nil {
            if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
                return "", nil, err
            }
            modelPath = mp
        } else {
            modelPath = modelFile
        }
    }

    res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
        Text:  text,
        Model: modelPath,
        Voice: voice,
        Dst:   filePath,
    })

    return filePath, res, err
}
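Callers consume the new service through the same JobResult pattern used by the CLI and HTTP layers below. A hedged usage sketch follows; the synthesize helper and the voice model name are illustrative, not part of the commit.

package main

import (
    "fmt"

    "github.com/go-skynet/LocalAI/core/backend"
    "github.com/go-skynet/LocalAI/core/schema"
)

// synthesize is a hypothetical caller; it assumes a TextToSpeechBackendService
// already constructed via backend.NewTextToSpeechBackendService.
func synthesize(ttsbs *backend.TextToSpeechBackendService, text string) (string, error) {
    jr := ttsbs.TextToAudioFile(&schema.TTSRequest{
        Model: "en-us-kathleen-low.onnx", // example Piper voice; any configured model name works
        Input: text,
    })
    pathPtr, err := jr.Wait() // blocks until the goroutine calls SetResult
    if err != nil {
        return "", err
    }
    if pathPtr == nil {
        return "", fmt.Errorf("received a nil file path")
    }
    return *pathPtr, nil
}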
@@ -126,16 +126,16 @@ func (r *RunCMD) Run(ctx *Context) error {
    }

    if r.PreloadBackendOnly {
-       _, _, _, err := startup.Startup(opts...)
+       _, err := startup.Startup(opts...)
        return err
    }

-   cl, ml, options, err := startup.Startup(opts...)
+   app, err := startup.Startup(opts...)
    if err != nil {
        return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
    }

-   appHTTP, err := http.App(cl, ml, options)
+   appHTTP, err := http.App(app)
    if err != nil {
        log.Error().Err(err).Msg("error during HTTP App construction")
        return err
@@ -9,6 +9,7 @@ import (

    "github.com/go-skynet/LocalAI/core/backend"
    "github.com/go-skynet/LocalAI/core/config"
+   "github.com/go-skynet/LocalAI/core/schema"
    "github.com/go-skynet/LocalAI/pkg/model"
    "github.com/rs/zerolog/log"
)

@@ -48,20 +49,32 @@ func (t *TTSCMD) Run(ctx *Context) error {
        }
    }()

-   options := config.BackendConfig{}
-   options.SetDefaults()
+   request := &schema.TTSRequest{
+       Backend: t.Backend,
+       Input:   text,
+       Model:   t.Model,
+       Voice:   t.Voice,
+   }

-   filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options)
+   ttsbs := backend.NewTextToSpeechBackendService(ml, config.NewBackendConfigLoader(), opts)
+
+   jr := ttsbs.TextToAudioFile(request)
+   filePathPtr, err := jr.Wait()
    if err != nil {
        return err
    }
+   if filePathPtr == nil {
+       err := fmt.Errorf("received a nil filepath from TextToAudioFile")
+       log.Error().Err(err).Msg("tts cli error")
+       return err
+   }
    if outputFile != "" {
-       if err := os.Rename(filePath, outputFile); err != nil {
+       if err := os.Rename(*filePathPtr, outputFile); err != nil {
            return err
        }
        fmt.Printf("Generate file %s\n", outputFile)
    } else {
-       fmt.Printf("Generate file %s\n", filePath)
+       fmt.Printf("Generate file %s\n", *filePathPtr)
    }
    return nil
}
@@ -1,12 +1,15 @@
package config

import (
    "encoding/json"
    "fmt"
    "os"

    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/go-skynet/LocalAI/pkg/downloader"
    "github.com/go-skynet/LocalAI/pkg/functions"
    "github.com/go-skynet/LocalAI/pkg/utils"
    "github.com/rs/zerolog/log"
)

const (

@@ -332,3 +335,203 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
        cfg.Debug = &trueV
    }
}

func (config *BackendConfig) UpdateFromOpenAIRequest(input *schema.OpenAIRequest) {
    if input.Echo {
        config.Echo = input.Echo
    }
    if input.TopK != nil {
        config.TopK = input.TopK
    }
    if input.TopP != nil {
        config.TopP = input.TopP
    }

    if input.Backend != "" {
        config.Backend = input.Backend
    }

    if input.ClipSkip != 0 {
        config.Diffusers.ClipSkip = input.ClipSkip
    }

    if input.ModelBaseName != "" {
        config.AutoGPTQ.ModelBaseName = input.ModelBaseName
    }

    if input.NegativePromptScale != 0 {
        config.NegativePromptScale = input.NegativePromptScale
    }

    if input.UseFastTokenizer {
        config.UseFastTokenizer = input.UseFastTokenizer
    }

    if input.NegativePrompt != "" {
        config.NegativePrompt = input.NegativePrompt
    }

    if input.RopeFreqBase != 0 {
        config.RopeFreqBase = input.RopeFreqBase
    }

    if input.RopeFreqScale != 0 {
        config.RopeFreqScale = input.RopeFreqScale
    }

    if input.Grammar != "" {
        config.Grammar = input.Grammar
    }

    if input.Temperature != nil {
        config.Temperature = input.Temperature
    }

    if input.Maxtokens != nil {
        config.Maxtokens = input.Maxtokens
    }

    switch stop := input.Stop.(type) {
    case string:
        if stop != "" {
            config.StopWords = append(config.StopWords, stop)
        }
    case []interface{}:
        for _, pp := range stop {
            if s, ok := pp.(string); ok {
                config.StopWords = append(config.StopWords, s)
            }
        }
    }

    if len(input.Tools) > 0 {
        for _, tool := range input.Tools {
            input.Functions = append(input.Functions, tool.Function)
        }
    }

    if input.ToolsChoice != nil {
        var toolChoice functions.Tool

        switch content := input.ToolsChoice.(type) {
        case string:
            _ = json.Unmarshal([]byte(content), &toolChoice)
        case map[string]interface{}:
            dat, _ := json.Marshal(content)
            _ = json.Unmarshal(dat, &toolChoice)
        }
        input.FunctionCall = map[string]interface{}{
            "name": toolChoice.Function.Name,
        }
    }

    // Decode each request's message content
    index := 0
    for i, m := range input.Messages {
        switch content := m.Content.(type) {
        case string:
            input.Messages[i].StringContent = content
        case []interface{}:
            dat, _ := json.Marshal(content)
            c := []schema.Content{}
            json.Unmarshal(dat, &c)
            for _, pp := range c {
                if pp.Type == "text" {
                    input.Messages[i].StringContent = pp.Text
                } else if pp.Type == "image_url" {
                    // Detect if pp.ImageURL is a URL; if it is, download the image and encode it in base64:
                    base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL)
                    if err == nil {
                        input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
                        // set a placeholder for each image
                        input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
                        index++
                    } else {
                        log.Error().Err(err).Msg("Failed encoding image")
                    }
                }
            }
        }
    }

    if input.RepeatPenalty != 0 {
        config.RepeatPenalty = input.RepeatPenalty
    }

    if input.FrequencyPenalty != 0 {
        config.FrequencyPenalty = input.FrequencyPenalty
    }

    if input.PresencePenalty != 0 {
        config.PresencePenalty = input.PresencePenalty
    }

    if input.Keep != 0 {
        config.Keep = input.Keep
    }

    if input.Batch != 0 {
        config.Batch = input.Batch
    }

    if input.IgnoreEOS {
        config.IgnoreEOS = input.IgnoreEOS
    }

    if input.Seed != nil {
        config.Seed = input.Seed
    }

    if input.TypicalP != nil {
        config.TypicalP = input.TypicalP
    }

    switch inputs := input.Input.(type) {
    case string:
        if inputs != "" {
            config.InputStrings = append(config.InputStrings, inputs)
        }
    case []interface{}:
        for _, pp := range inputs {
            switch i := pp.(type) {
            case string:
                config.InputStrings = append(config.InputStrings, i)
            case []interface{}:
                tokens := []int{}
                for _, ii := range i {
                    tokens = append(tokens, int(ii.(float64)))
                }
                config.InputToken = append(config.InputToken, tokens)
            }
        }
    }

    // Can be either a string or an object
    switch fnc := input.FunctionCall.(type) {
    case string:
        if fnc != "" {
            config.SetFunctionCallString(fnc)
        }
    case map[string]interface{}:
        var name string
        n, exists := fnc["name"]
        if exists {
            nn, e := n.(string)
            if e {
                name = nn
            }
        }
        config.SetFunctionCallNameString(name)
    }

    switch p := input.Prompt.(type) {
    case string:
        config.PromptStrings = append(config.PromptStrings, p)
    case []interface{}:
        for _, pp := range p {
            if s, ok := pp.(string); ok {
                config.PromptStrings = append(config.PromptStrings, s)
            }
        }
    }
}
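UpdateFromOpenAIRequest centralizes the per-request overrides that previously lived in the openai package's updateRequestConfig. A hedged sketch of the intended call sequence, mirroring the call sites in core/backend above; the prepareConfig helper name is illustrative, not part of the commit.

// Hypothetical helper showing how a backend service layers request fields on
// top of the YAML-loaded config.
func prepareConfig(bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, request *schema.OpenAIRequest) (*config.BackendConfig, error) {
    bc, err := bcl.LoadBackendConfigFileByName(request.Model, appConfig.ModelPath,
        config.LoadOptionDebug(appConfig.Debug),
        config.LoadOptionThreads(appConfig.Threads),
        config.LoadOptionContextSize(appConfig.ContextSize),
        config.LoadOptionF16(appConfig.F16),
    )
    if err != nil {
        return nil, err
    }
    // Per-request fields (temperature, stop words, prompts, tools, ...) win over the file values.
    bc.UpdateFromOpenAIRequest(request)
    return bc, nil
}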
@@ -6,16 +6,16 @@ import (
    "net/http"
    "strings"

    "github.com/go-skynet/LocalAI/core"
    "github.com/go-skynet/LocalAI/pkg/utils"

    "github.com/go-skynet/LocalAI/core/http/ctx"
    "github.com/go-skynet/LocalAI/core/http/endpoints/localai"
    "github.com/go-skynet/LocalAI/core/http/endpoints/openai"
    "github.com/go-skynet/LocalAI/core/http/routes"

    "github.com/go-skynet/LocalAI/core/config"
    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/go-skynet/LocalAI/core/services"
    "github.com/go-skynet/LocalAI/pkg/model"

    "github.com/gofiber/contrib/fiberzerolog"
    "github.com/gofiber/fiber/v2"

@@ -24,7 +24,6 @@ import (
    "github.com/gofiber/fiber/v2/middleware/filesystem"
    "github.com/gofiber/fiber/v2/middleware/recover"

    // swagger handler
    "github.com/rs/zerolog/log"
)

@@ -64,11 +63,11 @@ var embedDirStatic embed.FS
// @in header
// @name Authorization

-func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
+func App(application *core.Application) (*fiber.App, error) {
    // Return errors as JSON responses
    app := fiber.New(fiber.Config{
        Views: renderEngine(),
-       BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
+       BodyLimit: application.ApplicationConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
        // We disable the Fiber startup message as it does not conform to structured logging.
        // We register a startup log line with connection information in the OnListen hook to keep things user friendly though
        DisableStartupMessage: true,

@@ -109,7 +108,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi

    // Default middleware config

-   if !appConfig.Debug {
+   if !application.ApplicationConfig.Debug {
        app.Use(recover.New())
    }

@@ -127,11 +126,11 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi

    // Auth middleware checking if API key is valid. If no API key is set, no auth is required.
    auth := func(c *fiber.Ctx) error {
-       if len(appConfig.ApiKeys) == 0 {
+       if len(application.ApplicationConfig.ApiKeys) == 0 {
            return c.Next()
        }

-       if len(appConfig.ApiKeys) == 0 {
+       if len(application.ApplicationConfig.ApiKeys) == 0 {
            return c.Next()
        }

@@ -147,7 +146,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
        }

        apiKey := authHeaderParts[1]
-       for _, key := range appConfig.ApiKeys {
+       for _, key := range application.ApplicationConfig.ApiKeys {
            if apiKey == key {
                return c.Next()
            }

@@ -156,32 +155,36 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
    }

-   if appConfig.CORS {
+   if application.ApplicationConfig.CORS {
        var c func(ctx *fiber.Ctx) error
-       if appConfig.CORSAllowOrigins == "" {
+       if application.ApplicationConfig.CORSAllowOrigins == "" {
            c = cors.New()
        } else {
-           c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
+           c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig.CORSAllowOrigins})
        }

        app.Use(c)
    }

    // Load config jsons
-   utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
-   utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
-   utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
+   utils.LoadConfig(application.ApplicationConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
+   utils.LoadConfig(application.ApplicationConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
+   utils.LoadConfig(application.ApplicationConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)

-   galleryService := services.NewGalleryService(appConfig.ModelPath)
-   galleryService.Start(appConfig.Context, cl)
+   // Create the Fiber Content Extractor that the endpoints will use instead of the full modelLoader
+   fce := ctx.NewFiberContentExtractor(application.ModelLoader, application.ApplicationConfig)

-   routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
-   routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth)
-   routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
-   if !appConfig.DisableWebUI {
-       routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
+   // Register all routes - TODO: enhance for partial registration?
+   // For the "large" register function, it seems to make sense to pass application directly and allow them to sort out their dependencies.
+   // However, for particularly simple routes, passing dependencies directly may be more clean? Try both and experiment!
+   routes.RegisterElevenLabsRoutes(app, application.TextToSpeechBackendService, fce, auth)
+   routes.RegisterLocalAIRoutes(app, application, fce, auth)
+   routes.RegisterOpenAIRoutes(app, application, fce, auth)
+   routes.RegisterJINARoutes(app, application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig, auth)

+   if !application.ApplicationConfig.DisableWebUI {
+       routes.RegisterUIRoutes(app, application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig, application.GalleryService, auth)
    }
-   routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)

    httpFS := http.FS(embedDirStatic)
@@ -20,7 +20,7 @@ import (
    "github.com/go-skynet/LocalAI/pkg/downloader"
    "github.com/go-skynet/LocalAI/pkg/gallery"
    "github.com/go-skynet/LocalAI/pkg/model"

    "github.com/gofiber/fiber/v2"
    . "github.com/onsi/ginkgo/v2"
    . "github.com/onsi/gomega"

@@ -205,9 +205,6 @@ var _ = Describe("API test", func() {
    var cancel context.CancelFunc
    var tmpdir string
    var modelDir string
-   var bcl *config.BackendConfigLoader
-   var ml *model.ModelLoader
-   var applicationConfig *config.ApplicationConfig

    commonOpts := []config.AppOption{
        config.WithDebug(true),

@@ -251,7 +248,7 @@ var _ = Describe("API test", func() {
        },
    }

-   bcl, ml, applicationConfig, err = startup.Startup(
+   application, err := startup.Startup(
        append(commonOpts,
            config.WithContext(c),
            config.WithGalleries(galleries),

@@ -260,7 +257,7 @@ var _ = Describe("API test", func() {
            config.WithBackendAssetsOutput(backendAssetsDir))...)
    Expect(err).ToNot(HaveOccurred())

-   app, err = App(bcl, ml, applicationConfig)
+   app, err = App(application)
    Expect(err).ToNot(HaveOccurred())

    go app.Listen("127.0.0.1:9090")

@@ -607,7 +604,7 @@ var _ = Describe("API test", func() {
        },
    }

-   bcl, ml, applicationConfig, err = startup.Startup(
+   application, err := startup.Startup(
        append(commonOpts,
            config.WithContext(c),
            config.WithAudioDir(tmpdir),

@@ -618,7 +615,7 @@ var _ = Describe("API test", func() {
        config.WithBackendAssetsOutput(tmpdir))...,
    )
    Expect(err).ToNot(HaveOccurred())
-   app, err = App(bcl, ml, applicationConfig)
+   app, err = App(application)
    Expect(err).ToNot(HaveOccurred())

    go app.Listen("127.0.0.1:9090")

@@ -738,14 +735,14 @@ var _ = Describe("API test", func() {

    var err error

-   bcl, ml, applicationConfig, err = startup.Startup(
+   application, err := startup.Startup(
        append(commonOpts,
            config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
            config.WithContext(c),
            config.WithModelPath(modelPath),
        )...)
    Expect(err).ToNot(HaveOccurred())
-   app, err = App(bcl, ml, applicationConfig)
+   app, err = App(application)
    Expect(err).ToNot(HaveOccurred())
    go app.Listen("127.0.0.1:9090")

@@ -1024,14 +1021,14 @@ var _ = Describe("API test", func() {
    c, cancel = context.WithCancel(context.Background())

    var err error
-   bcl, ml, applicationConfig, err = startup.Startup(
+   application, err := startup.Startup(
        append(commonOpts,
            config.WithContext(c),
            config.WithModelPath(modelPath),
            config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
    )
    Expect(err).ToNot(HaveOccurred())
-   app, err = App(bcl, ml, applicationConfig)
+   app, err = App(application)
    Expect(err).ToNot(HaveOccurred())

    go app.Listen("127.0.0.1:9090")
@@ -0,0 +1,80 @@
package ctx

// This needs to be in a distinct package to avoid cycles between http, routes, and endpoints!

import (
    "context"
    "fmt"
    "strings"

    "github.com/go-skynet/LocalAI/core/config"
    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/go-skynet/LocalAI/pkg/model"
    "github.com/gofiber/fiber/v2"
    "github.com/rs/zerolog/log"
)

// This type largely exists to drop permissions from ModelLoader -
// various endpoint functions do not need access to the real "ModelLoader" api, and this type makes that clear.
type FiberContentExtractor struct {
    ml        *model.ModelLoader
    appConfig *config.ApplicationConfig
}

func NewFiberContentExtractor(ml *model.ModelLoader, appConfig *config.ApplicationConfig) *FiberContentExtractor {
    return &FiberContentExtractor{
        ml:        ml,
        appConfig: appConfig,
    }
}

// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func (fce *FiberContentExtractor) ModelFromContext(ctx *fiber.Ctx, modelInput string, firstModel bool) (string, error) {
    if ctx.Params("model") != "" {
        modelInput = ctx.Params("model")
    }

    // Set model from bearer token, if available
    bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
    bearerExists := bearer != "" && fce.ml.ExistsInModelPath(bearer)

    // If no model was specified, take the first available
    if modelInput == "" && !bearerExists && firstModel {
        models, _ := fce.ml.ListModels()
        if len(models) > 0 {
            modelInput = models[0]
            log.Debug().Msgf("No model specified, using: %s", modelInput)
        } else {
            log.Debug().Msgf("No model specified, returning error")
            return "", fmt.Errorf("no model specified")
        }
    }

    // A model found in the bearer token takes precedence
    if bearerExists {
        log.Debug().Msgf("Using model from bearer token: %s", bearer)
        modelInput = bearer
    }
    return modelInput, nil
}

func (fce *FiberContentExtractor) OpenAIRequestFromContext(ctx *fiber.Ctx, firstModel bool) (*schema.OpenAIRequest, error) {
    input := new(schema.OpenAIRequest)

    // Get input data from the request body
    if err := ctx.BodyParser(input); err != nil {
        return nil, fmt.Errorf("failed parsing request body: %w", err)
    }

    context, cancel := context.WithCancel(fce.appConfig.Context)
    input.Context = context
    input.Cancel = cancel

    modelName, err := fce.ModelFromContext(ctx, input.Model, firstModel)
    input.Model = modelName

    return input, err
}
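Endpoints depend on this narrow extractor instead of the full ModelLoader. A hedged sketch of an endpoint built on it; ChatEndpoint is a hypothetical example, not a handler added by this commit.

package openai

import (
    "fmt"

    "github.com/go-skynet/LocalAI/core/http/ctx"
    "github.com/gofiber/fiber/v2"
)

// ChatEndpoint is illustrative only: it shows the extractor doing the body
// parsing, context attachment, and model resolution for a handler.
func ChatEndpoint(fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        // Parses the body, attaches a cancellable context, and resolves the model name.
        request, err := fce.OpenAIRequestFromContext(c, true)
        if err != nil {
            return fmt.Errorf("failed reading parameters from request: %w", err)
        }
        // Hand request off to the appropriate backend service here.
        return c.JSON(fiber.Map{"model": request.Model})
    }
}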
@@ -1,43 +0,0 @@
package fiberContext

import (
    "fmt"
    "strings"

    "github.com/go-skynet/LocalAI/pkg/model"
    "github.com/gofiber/fiber/v2"
    "github.com/rs/zerolog/log"
)

// ModelFromContext returns the model from the context
// If no model is specified, it will take the first available
// Takes a model string as input which should be the one received from the user request.
// It returns the model name resolved from the context and an error if any.
func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) {
    if ctx.Params("model") != "" {
        modelInput = ctx.Params("model")
    }

    // Set model from bearer token, if available
    bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ")
    bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)

    // If no model was specified, take the first available
    if modelInput == "" && !bearerExists && firstModel {
        models, _ := loader.ListModels()
        if len(models) > 0 {
            modelInput = models[0]
            log.Debug().Msgf("No model specified, using: %s", modelInput)
        } else {
            log.Debug().Msgf("No model specified, returning error")
            return "", fmt.Errorf("no model specified")
        }
    }

    // If a model is found in bearer token takes precedence
    if bearerExists {
        log.Debug().Msgf("Using model from bearer token: %s", bearer)
        modelInput = bearer
    }
    return modelInput, nil
}
@@ -1,10 +1,10 @@
package elevenlabs

import (
    "fmt"

    "github.com/go-skynet/LocalAI/core/backend"
-   "github.com/go-skynet/LocalAI/core/config"
-   fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
-   "github.com/go-skynet/LocalAI/pkg/model"
+   "github.com/go-skynet/LocalAI/core/http/ctx"

    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/gofiber/fiber/v2"

@@ -17,7 +17,7 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
-func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func TTSEndpoint(ttsbs *backend.TextToSpeechBackendService, fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {

        input := new(schema.ElevenLabsTTSRequest)

@@ -28,34 +28,30 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
            return err
        }

-       modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false)
+       modelFile, err := fce.ModelFromContext(c, input.ModelID, false)
        if err != nil {
            modelFile = input.ModelID
            log.Warn().Msgf("Model not found in context: %s", input.ModelID)
        }

-       cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-           config.LoadOptionDebug(appConfig.Debug),
-           config.LoadOptionThreads(appConfig.Threads),
-           config.LoadOptionContextSize(appConfig.ContextSize),
-           config.LoadOptionF16(appConfig.F16),
-       )
-       if err != nil {
-           modelFile = input.ModelID
-           log.Warn().Msgf("Model not found in context: %s", input.ModelID)
-       } else {
-           if input.ModelID != "" {
-               modelFile = input.ModelID
-           } else {
-               modelFile = cfg.Model
-           }
-       }
-       log.Debug().Msgf("Request for model: %s", modelFile)
+       log.Debug().Str("modelName", modelFile).Msg("elevenlabs TTS request received for model")

-       filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
+       ttsRequest := &schema.TTSRequest{
+           Model: input.ModelID,
+           Input: input.Text,
+           Voice: voiceID,
+       }
+
+       jr := ttsbs.TextToAudioFile(ttsRequest)
+       filePathPtr, err := jr.Wait()
        if err != nil {
            return err
        }
-       return c.Download(filePath)
+       if filePathPtr == nil {
+           err := fmt.Errorf("received a nil filepath from TextToAudioFile")
+           log.Error().Err(err).Msg("elevenlabs TTSEndpoint error")
+           return err
+       }
+       return c.Download(*filePathPtr)
    }
}
@@ -4,7 +4,7 @@ import (
    "github.com/go-skynet/LocalAI/core/backend"
    "github.com/go-skynet/LocalAI/core/config"

-   fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
+   "github.com/go-skynet/LocalAI/core/http/ctx"
    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/go-skynet/LocalAI/pkg/grpc/proto"
    "github.com/go-skynet/LocalAI/pkg/model"

@@ -28,7 +28,9 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
        return err
    }

-   modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+   tempFCE := ctx.NewFiberContentExtractor(ml, appConfig)
+
+   modelFile, err := tempFCE.ModelFromContext(c, input.Model, false)
    if err != nil {
        modelFile = input.Model
        log.Warn().Msgf("Model not found in context: %s", input.Model)
@@ -1,10 +1,10 @@
package localai

import (
    "fmt"

    "github.com/go-skynet/LocalAI/core/backend"
-   "github.com/go-skynet/LocalAI/core/config"
-   fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
-   "github.com/go-skynet/LocalAI/pkg/model"
+   "github.com/go-skynet/LocalAI/core/http/ctx"

    "github.com/go-skynet/LocalAI/core/schema"
    "github.com/gofiber/fiber/v2"

@@ -16,45 +16,41 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/audio/speech [post]
-func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func TTSEndpoint(ttsbs *backend.TextToSpeechBackendService, fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {

        input := new(schema.TTSRequest)

        // Get input data from the request body
        if err := c.BodyParser(input); err != nil {
            log.Error().Err(err).Msg("Error during BodyParser")
            return err
        }

-       modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+       modelFile, err := fce.ModelFromContext(c, input.Model, false)
        if err != nil {
            modelFile = input.Model
            log.Warn().Msgf("Model not found in context: %s", input.Model)
        }

-       cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
-           config.LoadOptionDebug(appConfig.Debug),
-           config.LoadOptionThreads(appConfig.Threads),
-           config.LoadOptionContextSize(appConfig.ContextSize),
-           config.LoadOptionF16(appConfig.F16),
-       )

        if err != nil {
            modelFile = input.Model
-           log.Warn().Msgf("Model not found in context: %s", input.Model)
+           log.Warn().Str("input.Model", input.Model).Msg("Model not found in context, using input.Model")
        } else {
-           modelFile = cfg.Model
        }
-       log.Debug().Msgf("Request for model: %s", modelFile)

-       if input.Backend != "" {
-           cfg.Backend = input.Backend
+           log.Debug().Str("initial input.Model", input.Model).Str("modelFile", modelFile).Msg("overwriting input.Model with modelFile")
+           input.Model = modelFile
        }

-       filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
+       log.Debug().Str("modelName", modelFile).Msg("localai TTS request received for model")
+
+       jr := ttsbs.TextToAudioFile(input)
+       log.Debug().Msg("Obtained JobResult, waiting")
+       filePathPtr, err := jr.Wait()
        if err != nil {
+           log.Error().Err(err).Msg("Error during TextToAudioFile")
            return err
        }
-       return c.Download(filePath)
+       if filePathPtr == nil {
+           err := fmt.Errorf("received a nil filepath from TextToAudioFile")
+           log.Error().Err(err).Msg("localai TTSEndpoint error")
+           return err
+       }
+       log.Debug().Str("filePath", *filePathPtr).Msg("Successfully created output audio file at filePath")
+       return c.Download(*filePathPtr)
    }
}
@@ -1,16 +1,10 @@
package openai

import (
-   "encoding/json"
    "fmt"
-   "time"

    "github.com/go-skynet/LocalAI/core/backend"
-   "github.com/go-skynet/LocalAI/core/config"
-   "github.com/go-skynet/LocalAI/pkg/model"

-   "github.com/go-skynet/LocalAI/core/schema"
-   "github.com/google/uuid"
+   "github.com/go-skynet/LocalAI/core/http/ctx"

    "github.com/gofiber/fiber/v2"
    "github.com/rs/zerolog/log"

@@ -21,63 +15,27 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
-func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func EmbeddingsEndpoint(ebs *backend.EmbeddingsBackendService, fce *ctx.FiberContentExtractor) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
-       model, input, err := readRequest(c, ml, appConfig, true)
+       request, err := fce.OpenAIRequestFromContext(c, true)
        if err != nil {
-           return fmt.Errorf("failed reading parameters from request:%w", err)
+           return fmt.Errorf("failed reading parameters from request: %w", err)
        }

-       config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
+       jr := ebs.Embeddings(request)
+       resp, err := jr.Wait()

        if err != nil {
-           return fmt.Errorf("failed reading parameters from request:%w", err)
+           log.Error().Err(err).Msg("error during embedding")
+           return err
        }

-       log.Debug().Msgf("Parameter Config: %+v", config)
-       items := []schema.Item{}
-
-       for i, s := range config.InputToken {
-           // get the model function to call for the result
-           embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
-           if err != nil {
-               return err
-           }
-
-           embeddings, err := embedFn()
-           if err != nil {
-               return err
-           }
-           items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
+       if resp == nil {
+           err := fmt.Errorf("received a nil response from embeddings backend")
+           log.Error().Err(err).Msg("EmbeddingsEndpoint nil result")
+           return err
        }

-       for i, s := range config.InputStrings {
-           // get the model function to call for the result
-           embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
-           if err != nil {
-               return err
-           }
-
-           embeddings, err := embedFn()
-           if err != nil {
-               return err
-           }
-           items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
-       }
-
-       id := uuid.New().String()
-       created := int(time.Now().Unix())
-       resp := &schema.OpenAIResponse{
-           ID:      id,
-           Created: created,
-           Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-           Data:    items,
-           Object:  "list",
-       }
-
-       jsonResult, _ := json.Marshal(resp)
-       log.Debug().Msgf("Response: %s", jsonResult)
-
        // Return the prediction in the response body
-       return c.JSON(resp)
+       return c.JSON(*resp)
    }
}
@ -2,18 +2,13 @@ package openai
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
|
||||
"github.com/go-skynet/LocalAI/core/http/ctx"
|
||||
"github.com/go-skynet/LocalAI/core/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/functions"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
@ -28,253 +23,22 @@ func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfi
|
|||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
ctx, cancel := context.WithCancel(o.Context)
|
||||
input.Context = ctx
|
||||
context, cancel := context.WithCancel(o.Context)
|
||||
input.Context = context
|
||||
input.Cancel = cancel
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
|
||||
// TEMPORARY STUB DURING DEVELOPMENT
|
||||
fce := ctx.NewFiberContentExtractor(ml, o)
|
||||
|
||||
modelFile, err := fce.ModelFromContext(c, input.Model, firstModel)
|
||||
|
||||
return modelFile, input, err
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func getBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/...;base64,", drop it
|
||||
dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
|
||||
for _, prefix := range dropPrefix {
|
||||
if strings.HasPrefix(s, prefix) {
|
||||
return strings.ReplaceAll(s, prefix, ""), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != nil {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != nil {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != nil {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != nil {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
for _, tool := range input.Tools {
|
||||
input.Functions = append(input.Functions, tool.Function)
|
||||
}
|
||||
}
|
||||
|
||||
if input.ToolsChoice != nil {
|
||||
var toolChoice functions.Tool
|
||||
|
||||
switch content := input.ToolsChoice.(type) {
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(content), &toolChoice)
|
||||
case map[string]interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
_ = json.Unmarshal(dat, &toolChoice)
|
||||
}
|
||||
input.FunctionCall = map[string]interface{}{
|
||||
"name": toolChoice.Function.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
index := 0
|
||||
for i, m := range input.Messages {
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
for _, pp := range c {
|
||||
if pp.Type == "text" {
|
||||
input.Messages[i].StringContent = pp.Text
|
||||
} else if pp.Type == "image_url" {
|
||||
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
||||
base64, err := getBase64Image(pp.ImageURL.URL)
|
||||
if err == nil {
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
// set a placeholder for each image
|
||||
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
|
||||
index++
|
||||
} else {
|
||||
log.Error().Msgf("Failed encoding image: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.FrequencyPenalty != 0 {
|
||||
config.FrequencyPenalty = input.FrequencyPenalty
|
||||
}
|
||||
|
||||
if input.PresencePenalty != 0 {
|
||||
config.PresencePenalty = input.PresencePenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != nil {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.TypicalP != nil {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}

func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
	cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
		config.LoadOptionDebug(debug),
		config.LoadOptionThreads(threads),

@ -283,7 +47,7 @@ func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *c
	)

	// Set the parameters for the language model prediction
	updateRequestConfig(cfg, input)
	cfg.UpdateFromOpenAIRequest(input)

	return cfg, input, err
}
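The second hunk swaps the package-level updateRequestConfig helper for a method on the config type. The method body is not shown in this diff; as a sketch only, assuming it simply absorbs the guarded field copies above:

// Sketch; the real method lives in core/config and may differ.
func (cfg *BackendConfig) UpdateFromOpenAIRequest(input *schema.OpenAIRequest) {
	if input.RepeatPenalty != 0 {
		cfg.RepeatPenalty = input.RepeatPenalty
	}
	// ... one guarded copy per tunable field, as in updateRequestConfig above
}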

@ -1,19 +1,18 @@
package routes

import (
	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/http/ctx"
	"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
)

func RegisterElevenLabsRoutes(app *fiber.App,
	cl *config.BackendConfigLoader,
	ml *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	ttsbs *backend.TextToSpeechBackendService,
	fce *ctx.FiberContentExtractor,
	auth func(*fiber.Ctx) error) {

	// Elevenlabs
	app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))
	app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(ttsbs, fce))
}

@ -3,8 +3,8 @@ package routes
import (
	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core/http/endpoints/jina"

	"github.com/go-skynet/LocalAI/pkg/model"

	"github.com/gofiber/fiber/v2"
)
@ -1,27 +1,24 @@
package routes

import (
	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core"
	"github.com/go-skynet/LocalAI/core/http/ctx"
	"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/internal"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/swagger"
)

func RegisterLocalAIRoutes(app *fiber.App,
	cl *config.BackendConfigLoader,
	ml *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	galleryService *services.GalleryService,
	application *core.Application,
	fce *ctx.FiberContentExtractor,
	auth func(*fiber.Ctx) error) {

	app.Get("/swagger/*", swagger.HandlerDefault) // default

	// LocalAI API endpoints

	modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
	modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(application.ApplicationConfig.Galleries, application.ApplicationConfig.ModelPath, application.GalleryService)
	app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
	app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())
@ -32,14 +29,13 @@ func RegisterLocalAIRoutes(app *fiber.App,
	app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
	app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())

	app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
	app.Post("/tts", auth, localai.TTSEndpoint(application.TextToSpeechBackendService, fce))

	// Stores
	sl := model.NewModelLoader("")
	app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
	app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
	app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
	app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))
	// Stores: TODO: is this really a service, or purely a web API feature?
	app.Post("/stores/set", auth, localai.StoresSetEndpoint(application.StoresLoader, application.ApplicationConfig))
	app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(application.StoresLoader, application.ApplicationConfig))
	app.Post("/stores/get", auth, localai.StoresGetEndpoint(application.StoresLoader, application.ApplicationConfig))
	app.Post("/stores/find", auth, localai.StoresFindEndpoint(application.StoresLoader, application.ApplicationConfig))

	// Kubernetes health checks
	ok := func(c *fiber.Ctx) error {
@ -51,10 +47,8 @@ func RegisterLocalAIRoutes(app *fiber.App,

	app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())

	// Experimental Backend Statistics Module
	backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
	app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
	app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
	app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(application.BackendMonitorService))
	app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(application.BackendMonitorService))

	app.Get("/version", auth, func(c *fiber.Ctx) error {
		return c.JSON(struct {
@ -1,88 +1,85 @@
package routes

import (
	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core"
	"github.com/go-skynet/LocalAI/core/http/ctx"
	"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
	"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
)

func RegisterOpenAIRoutes(app *fiber.App,
	cl *config.BackendConfigLoader,
	ml *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	application *core.Application,
	fce *ctx.FiberContentExtractor,
	auth func(*fiber.Ctx) error) {
	// openAI compatible API endpoint

	// chat
	app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
	app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
	app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/chat/completions", auth, openai.ChatEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))

	// edit
	app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
	app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
	app.Post("/v1/edits", auth, openai.EditEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/edits", auth, openai.EditEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))

	// assistant
	app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
	app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
	app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
	app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
	app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
	app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
	app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
	app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
	app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
	app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
	app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
	app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
	app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
	app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
	app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
	app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
	app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
	app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
	app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))

	// files
	app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
	app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
	app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
	app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
	app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
	app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
	app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
	app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
	app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
	app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
	app.Post("/v1/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Post("/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Get("/v1/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Get("/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig))
	app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig))

	// completion
	app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
	app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
	app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
	app.Post("/v1/completions", auth, openai.CompletionEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/completions", auth, openai.CompletionEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))

	// embeddings
	app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
	app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
	app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
	app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(application.EmbeddingsBackendService, fce))
	app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(application.EmbeddingsBackendService, fce))
	app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(application.EmbeddingsBackendService, fce))

	// audio
	app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
	app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))
	app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))
	app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(application.TextToSpeechBackendService, fce))

	// images
	app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
	app.Post("/v1/images/generations", auth, openai.ImageEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig))

	if appConfig.ImageDir != "" {
		app.Static("/generated-images", appConfig.ImageDir)
	if application.ApplicationConfig.ImageDir != "" {
		app.Static("/generated-images", application.ApplicationConfig.ImageDir)
	}

	if appConfig.AudioDir != "" {
		app.Static("/generated-audio", appConfig.AudioDir)
	if application.ApplicationConfig.AudioDir != "" {
		app.Static("/generated-audio", application.ApplicationConfig.AudioDir)
	}

	// models
	tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
	app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
	app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
	app.Get("/v1/models", auth, openai.ListModelsEndpoint(application.ListModelsService))
	app.Get("/models", auth, openai.ListModelsEndpoint(application.ListModelsService))
}
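The embeddings routes above now take a backend service plus a FiberContentExtractor instead of three loose dependencies. A minimal sketch of what an endpoint built on that pair could look like, based on the EmbeddingsBackendService and JobResult types elsewhere in this change; the model-resolution step via fce is elided, since its API is not shown here:

// Hypothetical endpoint shape; not the actual implementation.
func EmbeddingsEndpoint(ebs *backend.EmbeddingsBackendService, fce *ctx.FiberContentExtractor) fiber.Handler {
	return func(c *fiber.Ctx) error {
		input := new(schema.OpenAIRequest)
		if err := c.BodyParser(input); err != nil {
			return err
		}
		// Kick off the job, then block until the backend goroutine calls SetResult.
		resp, err := ebs.Embeddings(input).Wait()
		if err != nil {
			return err
		}
		return c.JSON(*resp)
	}
}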
@ -21,6 +21,11 @@ type TTSRequest struct {
	Backend string `json:"backend" yaml:"backend"`
}

type RerankRequest struct {
	JINARerankRequest
	Backend string `json:"backend" yaml:"backend"`
}

type StoresSet struct {
	Store string `json:"store,omitempty" yaml:"store,omitempty"`
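Given the new RerankRequest above, a request body presumably looks like the following. Only "backend" is confirmed by this diff; the embedded JINARerankRequest fields ("model", "query", "documents", "top_n") are assumptions based on Jina's public rerank API:

{
  "model": "some-reranker",
  "query": "what is LocalAI?",
  "documents": ["LocalAI is a local OpenAI drop-in replacement.", "Unrelated text."],
  "top_n": 1,
  "backend": "some-backend"
}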
@ -5,6 +5,7 @@ import (
	"os"

	"github.com/go-skynet/LocalAI/core"
	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/internal"
@ -15,10 +16,10 @@ import (
	"github.com/rs/zerolog/log"
)

func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
	options := config.NewApplicationConfig(opts...)
func Startup(opts ...config.AppOption) (*core.Application, error) {
	appConfig := config.NewApplicationConfig(opts...)

	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", appConfig.Threads, appConfig.ModelPath)
	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
	caps, err := xsysinfo.CPUCapabilities()
	if err == nil {
@ -33,77 +34,78 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
	}

	// Make sure directories exist
	if options.ModelPath == "" {
		return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
	if appConfig.ModelPath == "" {
		return nil, fmt.Errorf("options.ModelPath cannot be empty")
	}
	err = os.MkdirAll(options.ModelPath, 0750)

	err = os.MkdirAll(appConfig.ModelPath, 0750)

	if err != nil {
		return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
		return nil, fmt.Errorf("unable to create ModelPath: %q", err)
	}
	if options.ImageDir != "" {
		err := os.MkdirAll(options.ImageDir, 0750)
	if appConfig.ImageDir != "" {
		err := os.MkdirAll(appConfig.ImageDir, 0750)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
			return nil, fmt.Errorf("unable to create ImageDir: %q", err)
		}
	}
	if options.AudioDir != "" {
		err := os.MkdirAll(options.AudioDir, 0750)
	if appConfig.AudioDir != "" {
		err := os.MkdirAll(appConfig.AudioDir, 0750)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
			return nil, fmt.Errorf("unable to create AudioDir: %q", err)
		}
	}
	if options.UploadDir != "" {
		err := os.MkdirAll(options.UploadDir, 0750)
	if appConfig.UploadDir != "" {
		err := os.MkdirAll(appConfig.UploadDir, 0750)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
			return nil, fmt.Errorf("unable to create UploadDir: %q", err)
		}
	}

	//
	pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...)
	// TODO DAVE INSPECT HERE
	pkgStartup.PreloadModelsConfigurations(appConfig.ModelLibraryURL, appConfig.ModelPath, appConfig.ModelsURL...)

	cl := config.NewBackendConfigLoader()
	ml := model.NewModelLoader(options.ModelPath)
	app := createApplication(appConfig)

	configLoaderOpts := options.ToConfigLoaderOptions()
	configLoaderOpts := appConfig.ToConfigLoaderOptions()

	if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil {
	if err := app.BackendConfigLoader.LoadBackendConfigsFromPath(appConfig.ModelPath, configLoaderOpts...); err != nil {
		log.Error().Err(err).Msg("error loading config files")
	}

	if options.ConfigFile != "" {
		if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil {
	if appConfig.ConfigFile != "" {
		if err := app.BackendConfigLoader.LoadBackendConfigFile(appConfig.ConfigFile, configLoaderOpts...); err != nil {
			log.Error().Err(err).Msg("error loading config file")
		}
	}

	if err := cl.Preload(options.ModelPath); err != nil {
	if err := app.BackendConfigLoader.Preload(appConfig.ModelPath); err != nil {
		log.Error().Err(err).Msg("error downloading models")
	}

	if options.PreloadJSONModels != "" {
		if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
			return nil, nil, nil, err
	if appConfig.PreloadJSONModels != "" {
		if err := services.ApplyGalleryFromString(appConfig.ModelPath, appConfig.PreloadJSONModels, app.BackendConfigLoader, appConfig.Galleries); err != nil {
			return nil, err
		}
	}

	if options.PreloadModelsFromPath != "" {
		if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
			return nil, nil, nil, err
	if appConfig.PreloadModelsFromPath != "" {
		if err := services.ApplyGalleryFromFile(appConfig.ModelPath, appConfig.PreloadModelsFromPath, app.BackendConfigLoader, appConfig.Galleries); err != nil {
			return nil, err
		}
	}

	if options.Debug {
		for _, v := range cl.ListBackendConfigs() {
			cfg, _ := cl.GetBackendConfig(v)
	if appConfig.Debug {
		for _, v := range app.BackendConfigLoader.ListBackendConfigs() {
			cfg, _ := app.BackendConfigLoader.GetBackendConfig(v)
			log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
		}
	}

	if options.AssetsDestination != "" {
	if appConfig.AssetsDestination != "" {
		// Extract files from the embedded FS
		err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
		log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
		err := assets.ExtractFiles(appConfig.BackendAssets, appConfig.AssetsDestination)
		log.Debug().Msgf("Extracting backend assets files to %s", appConfig.AssetsDestination)
		if err != nil {
			log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
		}
@ -111,25 +113,25 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

	// turn off any process that was started by GRPC if the context is canceled
	go func() {
		<-options.Context.Done()
		<-appConfig.Context.Done()
		log.Debug().Msgf("Context canceled, shutting down")
		err := ml.StopAllGRPC()
		err := app.ModelLoader.StopAllGRPC()
		if err != nil {
			log.Error().Err(err).Msg("error while stopping all grpc backends")
		}
	}()

	if options.WatchDog {
	if appConfig.WatchDog {
		wd := model.NewWatchDog(
			ml,
			options.WatchDogBusyTimeout,
			options.WatchDogIdleTimeout,
			options.WatchDogBusy,
			options.WatchDogIdle)
		ml.SetWatchDog(wd)
			app.ModelLoader,
			appConfig.WatchDogBusyTimeout,
			appConfig.WatchDogIdleTimeout,
			appConfig.WatchDogBusy,
			appConfig.WatchDogIdle)
		app.ModelLoader.SetWatchDog(wd)
		go wd.Run()
		go func() {
			<-options.Context.Done()
			<-appConfig.Context.Done()
			log.Debug().Msgf("Context canceled, shutting down")
			wd.Shutdown()
		}()
@ -137,14 +139,14 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode

	// Watch the configuration directory
	// If the directory does not exist, we don't watch it
	configHandler := newConfigFileHandler(options)
	configHandler := newConfigFileHandler(appConfig)
	err = configHandler.Watch()
	if err != nil {
		log.Error().Err(err).Msg("error establishing configuration directory watcher")
	}

	log.Info().Msg("core/startup process completed!")
	return cl, ml, options, nil
	return app, nil
}

// In lieu of a proper DI framework, this function wires up the Application manually.
@ -154,18 +156,23 @@ func createApplication(appConfig *config.ApplicationConfig) *core.Application {
		ApplicationConfig:   appConfig,
		BackendConfigLoader: config.NewBackendConfigLoader(),
		ModelLoader:         model.NewModelLoader(appConfig.ModelPath),
		StoresLoader:        model.NewModelLoader(""),
	}

	var err error

	// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	// app.RerankBackendService = backend.NewRerankBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)

	app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)

	app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath)
	app.GalleryService.Start(app.ApplicationConfig.Context, app.BackendConfigLoader)

	app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
	// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
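With the new signature, callers receive a single *core.Application instead of three loose values. A hypothetical call site, assuming the package is imported as startup:

// Sketch of the post-refactor call site; not taken from this diff.
app, err := startup.Startup(opts...)
if err != nil {
	log.Fatal().Err(err).Msg("startup failed")
}
// Everything now hangs off the Application:
_ = app.ModelLoader
_ = app.BackendConfigLoader
_ = app.TextToSpeechBackendService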
@ -3,4 +3,5 @@ vars {
  PORT: 8080
  DEFAULT_MODEL: gpt-3.5-turbo
  PROTOCOL: http://
  DEFAULT_TTS_MODEL: voice-en-us-kathleen-low
}
@ -16,7 +16,7 @@ headers {

body:json {
  {
    "model": "{{DEFAULT_MODEL}}",
    "model": "bert-embeddings",
    "input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?"
  }
}
@ -16,7 +16,8 @@ headers {

body:json {
  {
    "model": "{{DEFAULT_MODEL}}",
    "input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?"
    "model": "{{DEFAULT_TTS_MODEL}}",
    "input": "A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?",
    "backend": "piper"
  }
}
10 go.mod
@ -8,6 +8,7 @@ require (
	github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf
	github.com/Masterminds/sprig/v3 v3.2.3
	github.com/charmbracelet/glamour v0.7.0
	github.com/chasefleming/elem-go v0.25.0
	github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
	github.com/fsnotify/fsnotify v1.7.0
	github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e

@ -18,10 +19,12 @@ require (
	github.com/gofiber/swagger v1.0.0
	github.com/gofiber/template/html/v2 v2.1.1
	github.com/google/uuid v1.5.0
	github.com/hashicorp/go-multierror v1.1.1
	github.com/hpcloud/tail v1.0.0
	github.com/imdario/mergo v0.3.16
	github.com/jaypipes/ghw v0.12.0
	github.com/klauspost/cpuid/v2 v2.2.7
	github.com/mholt/archiver/v3 v3.5.1
	github.com/microcosm-cc/bluemonday v1.0.26
	github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
	github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
	github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231022042237-c25dc5193530

@ -74,7 +77,6 @@ require (
	github.com/beorn7/perks v1.0.1 // indirect
	github.com/cenkalti/backoff/v4 v4.1.3 // indirect
	github.com/cespare/xxhash/v2 v2.2.0 // indirect
	github.com/chasefleming/elem-go v0.25.0 // indirect
	github.com/containerd/continuity v0.3.0 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/dlclark/regexp2 v1.8.1 // indirect

@ -97,15 +99,12 @@ require (
	github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
	github.com/gorilla/css v1.0.1 // indirect
	github.com/huandu/xstrings v1.3.3 // indirect
	github.com/jaypipes/ghw v0.12.0 // indirect
	github.com/jaypipes/pcidb v1.0.0 // indirect
	github.com/josharian/intern v1.0.0 // indirect
	github.com/klauspost/cpuid/v2 v2.2.7 // indirect
	github.com/klauspost/pgzip v1.2.5 // indirect
	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
	github.com/mailru/easyjson v0.7.7 // indirect
	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
	github.com/microcosm-cc/bluemonday v1.0.26 // indirect
	github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
	github.com/mitchellh/copystructure v1.0.0 // indirect
	github.com/mitchellh/go-homedir v1.1.0 // indirect

@ -158,7 +157,6 @@ require (
	github.com/gofiber/contrib/fiberzerolog v1.0.0
	github.com/google/go-cmp v0.6.0 // indirect
	github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
	github.com/hashicorp/errwrap v1.0.0 // indirect
	github.com/joho/godotenv v1.5.1
	github.com/klauspost/compress v1.17.0 // indirect
	github.com/mattn/go-colorable v0.1.13 // indirect
16 go.sum
@ -149,14 +149,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=

@ -299,8 +293,6 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=

@ -391,8 +383,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=

@ -411,8 +401,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

@ -450,16 +438,12 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -0,0 +1,64 @@
package concurrency

import (
	"sync"
)

// This is a Read-ONLY structure that contains the result of an arbitrary asynchronous action
type JobResult[RequestType any, ResultType any] struct {
	Request *RequestType
	Result  *ResultType
	Error   error
	mu      *sync.Mutex
	done    *chan struct{}
}

// This structure is returned in a pair with a JobResult and serves as the structure that has access to be updated.
type WritableJobResult[RequestType any, ResultType any] struct {
	*JobResult[RequestType, ResultType]
}

// Wait blocks until the result is ready and then returns the result.
// Returns *ResultType instead of ResultType since it's possible we have only an error and nil for ResultType.
// Is this correct and idiomatic?
func (jr *JobResult[RequestType, ResultType]) Wait() (*ResultType, error) {
	if jr.done == nil { // If the channel is blanked out, the result is ready.
		return jr.Result, jr.Error
	}
	<-*jr.done // Wait for the result to be ready
	jr.mu.Lock()
	defer func() {
		jr.done = nil
		jr.mu.Unlock()
	}()
	if jr.Error != nil {
		return nil, jr.Error
	}
	return jr.Result, nil
}

// This is the function that actually updates the Result and Error on the JobResult... but it's normally not accessible
func (jr *JobResult[RequestType, ResultType]) setResult(result ResultType, err error) {
	jr.mu.Lock()
	defer jr.mu.Unlock()
	jr.Result = &result
	jr.Error = err
	close(*jr.done) // Signal that the result is ready
}

// Only the WritableJobResult can actually call setResult - prevents accidental corruption
func (wjr *WritableJobResult[RequestType, ResultType]) SetResult(result ResultType, err error) {
	wjr.JobResult.setResult(result, err)
}

// NewJobResult binds a request to a matched pair of JobResult and WritableJobResult
func NewJobResult[RequestType any, ResultType any](request RequestType) (*JobResult[RequestType, ResultType], *WritableJobResult[RequestType, ResultType]) {
	mu := &sync.Mutex{}
	done := make(chan struct{})
	jr := &JobResult[RequestType, ResultType]{
		mu:      mu,
		Request: &request,
		done:    &done,
	}
	return jr, &WritableJobResult[RequestType, ResultType]{JobResult: jr}
}
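Since the full type is shown above, a compact usage sketch is safe to give; the consumer blocks in Wait while a producer fills in the writable half:

// Minimal producer/consumer example for JobResult (illustrative, not part of the diff).
jr, wjr := NewJobResult[string, int]("compute length")
go func() {
	wjr.SetResult(len(*jr.Request), nil) // producer: sets the result and closes the done channel
}()
res, err := jr.Wait() // consumer: blocks until SetResult runs
if err == nil {
	fmt.Println(*res) // prints 14
}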

@ -0,0 +1,53 @@
package concurrency_test

import (
	"fmt"
	"time"

	. "github.com/go-skynet/LocalAI/pkg/concurrency"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("pkg/concurrency unit tests", func() {
	It("can be used to receive a result across goroutines", func() {
		jr, wjr := NewJobResult[string, string]("foo")
		Expect(jr).ToNot(BeNil())
		Expect(wjr).ToNot(BeNil())
		stallChannel := make(chan struct{})
		go func(jr *JobResult[string, string]) {
			resPtr, err := jr.Wait()
			Expect(err).To(BeNil())
			Expect(jr.Request).ToNot(BeNil())
			Expect(*jr.Request).To(Equal("foo"))
			Expect(resPtr).ToNot(BeNil())
			Expect(*resPtr).To(Equal("bar"))
			close(stallChannel)
		}(jr)
		go func(wjr *WritableJobResult[string, string]) {
			time.Sleep(time.Second * 5)
			wjr.SetResult("bar", nil)
		}(wjr)
		<-stallChannel
	})

	It("can be used to receive an error across goroutines", func() {
		jr, wjr := NewJobResult[string, string]("foo")
		Expect(jr).ToNot(BeNil())
		Expect(wjr).ToNot(BeNil())
		stallChannel := make(chan struct{})
		go func(jr *JobResult[string, string]) {
			_, err := jr.Wait()
			Expect(jr.Request).ToNot(BeNil())
			Expect(*jr.Request).To(Equal("foo"))
			Expect(err).ToNot(BeNil())
			Expect(err).To(MatchError("test"))
			close(stallChannel)
		}(jr)
		go func(wjr *WritableJobResult[string, string]) {
			time.Sleep(time.Second * 5)
			wjr.SetResult("", fmt.Errorf("test"))
		}(wjr)
		<-stallChannel
	})
})
@ -207,10 +207,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

	res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
	if err != nil {
		return "", fmt.Errorf("could not load model: %w", err)
		return "", fmt.Errorf("backend %q could not load model: %w", backend, err)
	}
	if !res.Success {
		return "", fmt.Errorf("could not load model (no success): %s", res.Message)
		return "", fmt.Errorf("backend %q could not load model (no success): %s", backend, res.Message)
	}

	return client, nil
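Note that the new messages keep %w, so the underlying gRPC error remains matchable through the wrapped chain, for example (the backend name here is illustrative):

// Because %w wraps rather than formats, the original cause survives:
wrapped := fmt.Errorf("backend %q could not load model: %w", "llama-cpp", io.ErrUnexpectedEOF)
fmt.Println(errors.Is(wrapped, io.ErrUnexpectedEOF)) // true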
@ -42,9 +42,13 @@ func GetImageURLAsBase64(s string) (string, error) {
		return encoded, nil
	}

	// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
	if strings.HasPrefix(s, "data:image/jpeg;base64,") {
		return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
	// if the string instead is prefixed with "data:image/...;base64,", drop it
	dropPrefix := []string{"data:image/jpeg;base64,", "data:image/png;base64,"}
	for _, prefix := range dropPrefix {
		if strings.HasPrefix(s, prefix) {
			return strings.ReplaceAll(s, prefix, ""), nil
		}
	}

	return "", fmt.Errorf("not valid string")
	}
}
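One caution on the hunk above: strings.ReplaceAll removes the prefix substring anywhere in the payload, not only at the front. If that ever matters, strings.TrimPrefix is the tighter tool (a suggestion, not part of this change):

for _, prefix := range dropPrefix {
	if strings.HasPrefix(s, prefix) {
		return strings.TrimPrefix(s, prefix), nil // strips only the leading occurrence
	}
}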
@ -7,13 +7,20 @@ import (
)

var _ = Describe("utils/base64 tests", func() {
	It("GetImageURLAsBase64 can strip data url prefixes", func() {
	It("GetImageURLAsBase64 can strip jpeg data url prefixes", func() {
		// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
		input := "data:image/jpeg;base64,FOO"
		b64, err := GetImageURLAsBase64(input)
		Expect(err).To(BeNil())
		Expect(b64).To(Equal("FOO"))
	})
	It("GetImageURLAsBase64 can strip png data url prefixes", func() {
		// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
		input := "data:image/png;base64,BAR"
		b64, err := GetImageURLAsBase64(input)
		Expect(err).To(BeNil())
		Expect(b64).To(Equal("BAR"))
	})
	It("GetImageURLAsBase64 returns an error for bogus data", func() {
		input := "FOO"
		b64, err := GetImageURLAsBase64(input)
@ -40,7 +40,9 @@ var _ = BeforeSuite(func() {
	if apiEndpoint == "" {
		startDockerImage()
		defaultConfig = openai.DefaultConfig(apiKey)
		defaultConfig.BaseURL = "http://localhost:" + apiPort + "/v1"
		apiEndpoint = "http://localhost:" + apiPort + "/v1" // So that other tests can reference this value safely.
		defaultConfig.BaseURL = apiEndpoint
	} else {
		fmt.Println("Default ", apiEndpoint)
		defaultConfig = openai.DefaultConfig(apiKey)