diff --git a/api/api.go b/api/api.go index de18e182..57cf968f 100644 --- a/api/api.go +++ b/api/api.go @@ -9,6 +9,7 @@ import ( "github.com/go-skynet/LocalAI/api/localai" "github.com/go-skynet/LocalAI/api/openai" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" @@ -104,8 +105,8 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Send custom error page return ctx.Status(code).JSON( - openai.ErrorResponse{ - Error: &openai.APIError{Message: err.Error(), Code: code}, + schema.ErrorResponse{ + Error: &schema.APIError{Message: err.Error(), Code: code}, }, ) }, diff --git a/api/backend/transcript.go b/api/backend/transcript.go index fbc2b7ec..77427839 100644 --- a/api/backend/transcript.go +++ b/api/backend/transcript.go @@ -5,14 +5,14 @@ import ( "fmt" config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { +func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) { opts := modelOpts(c, o, []model.Option{ model.WithBackendString(model.WhisperBackend), diff --git a/api/openai/chat.go b/api/openai/chat.go index 6393e5d8..9c4a956f 100644 --- a/api/openai/chat.go +++ b/api/openai/chat.go @@ -10,6 +10,7 @@ import ( "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -21,20 +22,20 @@ import ( func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { emptyMessage := "" - process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - initialMessage := OpenAIResponse{ + process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + initialMessage := schema.OpenAIResponse{ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Role: "assistant", Content: &emptyMessage}}}, + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, Object: "chat.completion.chunk", } responses <- initialMessage - ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool { - resp := OpenAIResponse{ + ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, + Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, Object: "chat.completion.chunk", - Usage: OpenAIUsage{ + Usage: schema.OpenAIUsage{ PromptTokens: usage.Prompt, CompletionTokens: usage.Completion, TotalTokens: usage.Prompt + usage.Completion, @@ -236,13 +237,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } if toStream { - responses := make(chan OpenAIResponse) + responses := make(chan schema.OpenAIResponse) go process(predInput, input, config, o.Loader, responses) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - usage := &OpenAIUsage{} + usage := &schema.OpenAIUsage{} for ev := range responses { usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it @@ -259,13 +260,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) w.Flush() } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ + Choices: []schema.Choice{ { FinishReason: "stop", Index: 0, - Delta: &Message{Content: &emptyMessage}, + Delta: &schema.Message{Content: &emptyMessage}, }}, Object: "chat.completion.chunk", Usage: *usage, @@ -279,7 +280,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return nil } - result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) { + result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { if processFunctions { // As we have to change the result before processing, we can't stream the answer (yet?) ss := map[string]interface{}{} @@ -313,7 +314,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) message = backend.Finetune(*config, predInput, message) log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) + *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}}) return } } @@ -336,28 +337,28 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &fineTunedResponse}}) + *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}}) } else { // otherwise reply with the function call - *c = append(*c, Choice{ + *c = append(*c, schema.Choice{ FinishReason: "function_call", - Message: &Message{Role: "assistant", FunctionCall: ss}, + Message: &schema.Message{Role: "assistant", FunctionCall: ss}, }) } return } - *c = append(*c, Choice{FinishReason: "stop", Index: 0, Message: &Message{Role: "assistant", Content: &s}}) + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) }, nil) if err != nil { return err } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: result, Object: "chat.completion", - Usage: OpenAIUsage{ + Usage: schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, CompletionTokens: tokenUsage.Completion, TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, diff --git a/api/openai/completion.go b/api/openai/completion.go index 20d15d4e..00ddd910 100644 --- a/api/openai/completion.go +++ b/api/openai/completion.go @@ -10,6 +10,7 @@ import ( "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -18,18 +19,18 @@ import ( // https://platform.openai.com/docs/api-reference/completions func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string, usage backend.TokenUsage) bool { - resp := OpenAIResponse{ + process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ + Choices: []schema.Choice{ { Index: 0, Text: s, }, }, Object: "text_completion", - Usage: OpenAIUsage{ + Usage: schema.OpenAIUsage{ PromptTokens: usage.Prompt, CompletionTokens: usage.Completion, TotalTokens: usage.Prompt + usage.Completion, @@ -90,7 +91,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Template found, input modified to: %s", predInput) } - responses := make(chan OpenAIResponse) + responses := make(chan schema.OpenAIResponse) go process(predInput, input, config, o.Loader, responses) @@ -106,9 +107,9 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe w.Flush() } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ + Choices: []schema.Choice{ { Index: 0, FinishReason: "stop", @@ -125,7 +126,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe return nil } - var result []Choice + var result []schema.Choice totalTokenUsage := backend.TokenUsage{} @@ -140,9 +141,10 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Template found, input modified to: %s", i) } - r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k}) - }, nil) + r, tokenUsage, err := ComputeChoices( + input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) + }, nil) if err != nil { return err } @@ -153,11 +155,11 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe result = append(result, r...) } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: result, Object: "text_completion", - Usage: OpenAIUsage{ + Usage: schema.OpenAIUsage{ PromptTokens: totalTokenUsage.Prompt, CompletionTokens: totalTokenUsage.Completion, TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, diff --git a/api/openai/edit.go b/api/openai/edit.go index 6b4664df..8a89ab28 100644 --- a/api/openai/edit.go +++ b/api/openai/edit.go @@ -7,8 +7,10 @@ import ( "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" ) @@ -32,7 +34,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) templateFile = config.TemplateConfig.Edit } - var result []Choice + var result []schema.Choice totalTokenUsage := backend.TokenUsage{} for _, i := range config.InputStrings { @@ -47,8 +49,8 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) log.Debug().Msgf("Template found, input modified to: %s", i) } - r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) + r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) }, nil) if err != nil { return err @@ -60,11 +62,11 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) result = append(result, r...) } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: result, Object: "edit", - Usage: OpenAIUsage{ + Usage: schema.OpenAIUsage{ PromptTokens: totalTokenUsage.Prompt, CompletionTokens: totalTokenUsage.Completion, TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go index 7d47060d..37b0a079 100644 --- a/api/openai/embeddings.go +++ b/api/openai/embeddings.go @@ -6,6 +6,8 @@ import ( "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/api/options" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -25,7 +27,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe } log.Debug().Msgf("Parameter Config: %+v", config) - items := []Item{} + items := []schema.Item{} for i, s := range config.InputToken { // get the model function to call for the result @@ -38,7 +40,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe if err != nil { return err } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) } for i, s := range config.InputStrings { @@ -52,10 +54,10 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe if err != nil { return err } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Data: items, Object: "list", diff --git a/api/openai/image.go b/api/openai/image.go index e5b32e82..9ab8fd3a 100644 --- a/api/openai/image.go +++ b/api/openai/image.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/go-skynet/LocalAI/api/schema" "os" "path/filepath" "strconv" @@ -100,7 +101,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx b64JSON = true } // src and clip_skip - var result []Item + var result []schema.Item for _, i := range config.PromptStrings { n := input.N if input.N == 0 { @@ -155,7 +156,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx return err } - item := &Item{} + item := &schema.Item{} if b64JSON { defer os.RemoveAll(output) @@ -173,7 +174,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx } } - resp := &OpenAIResponse{ + resp := &schema.OpenAIResponse{ Data: result, } diff --git a/api/openai/inference.go b/api/openai/inference.go index 2f34d82e..d835c167 100644 --- a/api/openai/inference.go +++ b/api/openai/inference.go @@ -4,12 +4,20 @@ import ( "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" model "github.com/go-skynet/LocalAI/pkg/model" ) -func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string, backend.TokenUsage) bool) ([]Choice, backend.TokenUsage, error) { +func ComputeChoices( + req *schema.OpenAIRequest, + predInput string, + config *config.Config, + o *options.Option, + loader *model.ModelLoader, + cb func(string, *[]schema.Choice), + tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { n := req.N // number of completions to return - result := []Choice{} + result := []schema.Choice{} if n == 0 { n = 1 diff --git a/api/openai/list.go b/api/openai/list.go index 59159921..8bc5bbe2 100644 --- a/api/openai/list.go +++ b/api/openai/list.go @@ -4,6 +4,7 @@ import ( "regexp" config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) @@ -16,7 +17,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func } var mm map[string]interface{} = map[string]interface{}{} - dataModels := []OpenAIModel{} + dataModels := []schema.OpenAIModel{} var filterFn func(name string) bool filter := c.Query("filter") @@ -45,7 +46,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func } if filterFn(c.Name) { - dataModels = append(dataModels, OpenAIModel{ID: c.Name, Object: "model"}) + dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) } } @@ -53,13 +54,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func for _, m := range models { // And only adds them if they shouldn't be skipped. 
if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) + dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) } } return c.JSON(struct { - Object string `json:"object"` - Data []OpenAIModel `json:"data"` + Object string `json:"object"` + Data []schema.OpenAIModel `json:"data"` }{ Object: "list", Data: dataModels, diff --git a/api/openai/request.go b/api/openai/request.go index 46e500b8..ef4d7f6f 100644 --- a/api/openai/request.go +++ b/api/openai/request.go @@ -10,14 +10,15 @@ import ( config "github.com/go-skynet/LocalAI/api/config" options "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/api/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) -func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) { +func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) { loader := o.Loader - input := new(OpenAIRequest) + input := new(schema.OpenAIRequest) ctx, cancel := context.WithCancel(o.Context) input.Context = ctx input.Cancel = cancel @@ -60,7 +61,7 @@ func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *Open return modelFile, input, nil } -func updateConfig(config *config.Config, input *OpenAIRequest) { +func updateConfig(config *config.Config, input *schema.OpenAIRequest) { if input.Echo { config.Echo = input.Echo } @@ -218,7 +219,7 @@ func updateConfig(config *config.Config, input *OpenAIRequest) { } } -func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { +func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) { // Load a config file if present after the model name modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") diff --git a/api/openai/api.go b/api/schema/openai.go similarity index 99% rename from api/openai/api.go rename to api/schema/openai.go index 2e3f154c..639125fa 100644 --- a/api/openai/api.go +++ b/api/schema/openai.go @@ -1,4 +1,4 @@ -package openai +package schema import ( "context" diff --git a/pkg/grpc/whisper/api/api.go b/api/schema/whisper.go similarity index 95% rename from pkg/grpc/whisper/api/api.go rename to api/schema/whisper.go index 700d80e7..41413c1f 100644 --- a/pkg/grpc/whisper/api/api.go +++ b/api/schema/whisper.go @@ -1,4 +1,4 @@ -package api +package schema import "time" diff --git a/cmd/grpc/bert-embeddings/main.go b/cmd/grpc/bert-embeddings/main.go index 008c30d5..90fae8b2 100644 --- a/cmd/grpc/bert-embeddings/main.go +++ b/cmd/grpc/bert-embeddings/main.go @@ -5,8 +5,8 @@ package main import ( "flag" + bert "github.com/go-skynet/LocalAI/pkg/backend/llm/bert" grpc "github.com/go-skynet/LocalAI/pkg/grpc" - bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert" ) var ( diff --git a/cmd/grpc/bloomz/main.go b/cmd/grpc/bloomz/main.go index 7348cab0..8d6303ba 100644 --- a/cmd/grpc/bloomz/main.go +++ b/cmd/grpc/bloomz/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz" + bloomz "github.com/go-skynet/LocalAI/pkg/backend/llm/bloomz" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/dolly/main.go 
b/cmd/grpc/dolly/main.go index 43bba92f..cc040cb6 100644 --- a/cmd/grpc/dolly/main.go +++ b/cmd/grpc/dolly/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/falcon-ggml/main.go b/cmd/grpc/falcon-ggml/main.go index 677c660d..5a99a915 100644 --- a/cmd/grpc/falcon-ggml/main.go +++ b/cmd/grpc/falcon-ggml/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/falcon/main.go b/cmd/grpc/falcon/main.go index 9ccead4d..8ddf6236 100644 --- a/cmd/grpc/falcon/main.go +++ b/cmd/grpc/falcon/main.go @@ -7,7 +7,7 @@ package main import ( "flag" - falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon" + falcon "github.com/go-skynet/LocalAI/pkg/backend/llm/falcon" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/gpt2/main.go b/cmd/grpc/gpt2/main.go index d9fe2752..42f6ba47 100644 --- a/cmd/grpc/gpt2/main.go +++ b/cmd/grpc/gpt2/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/gpt4all/main.go b/cmd/grpc/gpt4all/main.go index a784d401..e659fe77 100644 --- a/cmd/grpc/gpt4all/main.go +++ b/cmd/grpc/gpt4all/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all" + gpt4all "github.com/go-skynet/LocalAI/pkg/backend/llm/gpt4all" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/gptj/main.go b/cmd/grpc/gptj/main.go index 27d82104..3530a727 100644 --- a/cmd/grpc/gptj/main.go +++ b/cmd/grpc/gptj/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/gptneox/main.go b/cmd/grpc/gptneox/main.go index 3d005ca8..d4e8be47 100644 --- a/cmd/grpc/gptneox/main.go +++ b/cmd/grpc/gptneox/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/langchain-huggingface/main.go b/cmd/grpc/langchain-huggingface/main.go index ab965848..938908d2 100644 --- a/cmd/grpc/langchain-huggingface/main.go +++ b/cmd/grpc/langchain-huggingface/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain" + langchain "github.com/go-skynet/LocalAI/pkg/backend/llm/langchain" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/llama/main.go b/cmd/grpc/llama/main.go index d75ef481..442f84ef 100644 --- a/cmd/grpc/llama/main.go +++ b/cmd/grpc/llama/main.go @@ -7,7 +7,7 @@ package main import ( "flag" - llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama" + llama "github.com/go-skynet/LocalAI/pkg/backend/llm/llama" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/mpt/main.go b/cmd/grpc/mpt/main.go index 
58456a7d..6f46d187 100644 --- a/cmd/grpc/mpt/main.go +++ b/cmd/grpc/mpt/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/piper/main.go b/cmd/grpc/piper/main.go index 7de80e24..d77189ae 100644 --- a/cmd/grpc/piper/main.go +++ b/cmd/grpc/piper/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - tts "github.com/go-skynet/LocalAI/pkg/grpc/tts" + tts "github.com/go-skynet/LocalAI/pkg/backend/tts" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/replit/main.go b/cmd/grpc/replit/main.go index aed67fbc..0f54b6ab 100644 --- a/cmd/grpc/replit/main.go +++ b/cmd/grpc/replit/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/rwkv/main.go b/cmd/grpc/rwkv/main.go index f050a7c5..74724dc2 100644 --- a/cmd/grpc/rwkv/main.go +++ b/cmd/grpc/rwkv/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv" + rwkv "github.com/go-skynet/LocalAI/pkg/backend/llm/rwkv" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/stablediffusion/main.go b/cmd/grpc/stablediffusion/main.go index 76b4a5af..d12fe3c5 100644 --- a/cmd/grpc/stablediffusion/main.go +++ b/cmd/grpc/stablediffusion/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - image "github.com/go-skynet/LocalAI/pkg/grpc/image" + image "github.com/go-skynet/LocalAI/pkg/backend/image" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/starcoder/main.go b/cmd/grpc/starcoder/main.go index 2847acf7..c08cef20 100644 --- a/cmd/grpc/starcoder/main.go +++ b/cmd/grpc/starcoder/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + transformers "github.com/go-skynet/LocalAI/pkg/backend/llm/transformers" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/cmd/grpc/whisper/main.go b/cmd/grpc/whisper/main.go index 8d4a5fea..4896c4f9 100644 --- a/cmd/grpc/whisper/main.go +++ b/cmd/grpc/whisper/main.go @@ -5,7 +5,7 @@ package main import ( "flag" - transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe" + transcribe "github.com/go-skynet/LocalAI/pkg/backend/transcribe" grpc "github.com/go-skynet/LocalAI/pkg/grpc" ) diff --git a/pkg/grpc/image/stablediffusion.go b/pkg/backend/image/stablediffusion.go similarity index 98% rename from pkg/grpc/image/stablediffusion.go rename to pkg/backend/image/stablediffusion.go index 600f2007..50d299da 100644 --- a/pkg/grpc/image/stablediffusion.go +++ b/pkg/backend/image/stablediffusion.go @@ -9,7 +9,7 @@ import ( ) type StableDiffusion struct { - base.Base + base.SingleThread stablediffusion *stablediffusion.StableDiffusion } diff --git a/pkg/grpc/llm/bert/bert.go b/pkg/backend/llm/bert/bert.go similarity index 75% rename from pkg/grpc/llm/bert/bert.go rename to pkg/backend/llm/bert/bert.go index abdf0102..423ff79c 100644 --- a/pkg/grpc/llm/bert/bert.go +++ b/pkg/backend/llm/bert/bert.go @@ -4,32 +4,23 @@ package bert // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( bert "github.com/go-skynet/go-bert.cpp" - "github.com/rs/zerolog/log" 
"github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) type Embeddings struct { - base.Base + base.SingleThread bert *bert.Bert } func (llm *Embeddings) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("bert backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := bert.New(opts.ModelFile) llm.bert = model return err } func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - llm.Base.Lock() - defer llm.Base.Unlock() if len(opts.EmbeddingTokens) > 0 { tokens := []int{} diff --git a/pkg/grpc/llm/bloomz/bloomz.go b/pkg/backend/llm/bloomz/bloomz.go similarity index 80% rename from pkg/grpc/llm/bloomz/bloomz.go rename to pkg/backend/llm/bloomz/bloomz.go index 304bab30..0775c77d 100644 --- a/pkg/grpc/llm/bloomz/bloomz.go +++ b/pkg/backend/llm/bloomz/bloomz.go @@ -7,24 +7,17 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" "github.com/go-skynet/bloomz.cpp" ) type LLM struct { - base.Base + base.SingleThread bloomz *bloomz.Bloomz } func (llm *LLM) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("bloomz backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := bloomz.New(opts.ModelFile) llm.bloomz = model return err @@ -47,16 +40,11 @@ func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() - go func() { res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -65,7 +53,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro } results <- res close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/backend/llm/falcon/falcon.go similarity index 92% rename from pkg/grpc/llm/falcon/falcon.go rename to pkg/backend/llm/falcon/falcon.go index c2b9d03f..4b96b71f 100644 --- a/pkg/grpc/llm/falcon/falcon.go +++ b/pkg/backend/llm/falcon/falcon.go @@ -7,25 +7,17 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" ggllm "github.com/mudler/go-ggllm.cpp" ) type LLM struct { - base.Base + base.SingleThread falcon *ggllm.Falcon } func (llm *LLM) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() - ggllmOpts := []ggllm.ModelOption{} if opts.ContextSize != 0 { ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize))) @@ -126,13 +118,10 @@ func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) 
} func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() predictOptions := buildPredictOptions(opts) @@ -150,7 +139,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro fmt.Println("err: ", err) } close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/gpt4all/gpt4all.go b/pkg/backend/llm/gpt4all/gpt4all.go similarity index 82% rename from pkg/grpc/llm/gpt4all/gpt4all.go rename to pkg/backend/llm/gpt4all/gpt4all.go index 0b485120..86d4baa3 100644 --- a/pkg/grpc/llm/gpt4all/gpt4all.go +++ b/pkg/backend/llm/gpt4all/gpt4all.go @@ -8,23 +8,15 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" - "github.com/rs/zerolog/log" ) type LLM struct { - base.Base + base.SingleThread gpt4all *gpt4all.Model } func (llm *LLM) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("gpt4all backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() - model, err := gpt4all.New(opts.ModelFile, gpt4all.SetThreads(int(opts.Threads)), gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath)) @@ -47,15 +39,10 @@ func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() - predictOptions := buildPredictOptions(opts) go func() { @@ -69,7 +56,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro } llm.gpt4all.SetTokenCallback(nil) close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/langchain/langchain.go b/pkg/backend/llm/langchain/langchain.go similarity index 82% rename from pkg/grpc/llm/langchain/langchain.go rename to pkg/backend/llm/langchain/langchain.go index cd3fd12b..5d5f94bd 100644 --- a/pkg/grpc/llm/langchain/langchain.go +++ b/pkg/backend/llm/langchain/langchain.go @@ -8,7 +8,6 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/langchain" - "github.com/rs/zerolog/log" ) type LLM struct { @@ -19,21 +18,12 @@ type LLM struct { } func (llm *LLM) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("langchain backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() llm.langchain, _ = langchain.NewHuggingFace(opts.Model) llm.model = opts.Model return nil } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - o := []langchain.PredictOption{ langchain.SetModel(llm.model), langchain.SetMaxTokens(int(opts.Tokens)), @@ -48,7 +38,6 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() o := []langchain.PredictOption{ langchain.SetModel(llm.model), langchain.SetMaxTokens(int(opts.Tokens)), @@ -63,7 +52,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro } results <- res.Completion 
close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/backend/llm/llama/llama.go similarity index 93% rename from pkg/grpc/llm/llama/llama.go rename to pkg/backend/llm/llama/llama.go index 594dfb97..62040233 100644 --- a/pkg/grpc/llm/llama/llama.go +++ b/pkg/backend/llm/llama/llama.go @@ -8,24 +8,15 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/go-llama.cpp" - "github.com/rs/zerolog/log" ) type LLM struct { - base.Base + base.SingleThread llama *llama.LLama } func (llm *LLM) Load(opts *pb.ModelOptions) error { - - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("llama backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() - ropeFreqBase := float32(10000) ropeFreqScale := float32(1) @@ -176,14 +167,10 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() - predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { @@ -197,16 +184,12 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro fmt.Println("err: ", err) } close(results) - llm.Base.Unlock() }() return nil } func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - predictOptions := buildPredictOptions(opts) if len(opts.EmbeddingTokens) > 0 { @@ -221,9 +204,6 @@ func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { } func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - predictOptions := buildPredictOptions(opts) l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...) 
if err != nil { diff --git a/pkg/grpc/llm/rwkv/rwkv.go b/pkg/backend/llm/rwkv/rwkv.go similarity index 83% rename from pkg/grpc/llm/rwkv/rwkv.go rename to pkg/backend/llm/rwkv/rwkv.go index 3658befb..bd7142ef 100644 --- a/pkg/grpc/llm/rwkv/rwkv.go +++ b/pkg/backend/llm/rwkv/rwkv.go @@ -9,24 +9,17 @@ import ( "github.com/donomii/go-rwkv.cpp" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" ) const tokenizerSuffix = ".tokenizer.json" type LLM struct { - base.Base + base.SingleThread rwkv *rwkv.RwkvState } func (llm *LLM) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("rwkv backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() modelPath := filepath.Dir(opts.ModelFile) modelFile := filepath.Base(opts.ModelFile) model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads())) @@ -39,9 +32,6 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error { } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - stopWord := "\n" if len(opts.StopPrompts) > 0 { stopWord = opts.StopPrompts[0] @@ -57,7 +47,6 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { stopWord := "\n" @@ -75,7 +64,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro return true }) close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/transformers/dolly.go b/pkg/backend/llm/transformers/dolly.go similarity index 75% rename from pkg/grpc/llm/transformers/dolly.go rename to pkg/backend/llm/transformers/dolly.go index 220490a7..b3579b04 100644 --- a/pkg/grpc/llm/transformers/dolly.go +++ b/pkg/backend/llm/transformers/dolly.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Dolly struct { - base.Base + base.SingleThread dolly *transformers.Dolly } func (llm *Dolly) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("dolly backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewDolly(opts.ModelFile) llm.dolly = model return err } func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -48,7 +38,6 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) er } results <- res close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/transformers/falcon.go b/pkg/backend/llm/transformers/falcon.go similarity index 74% rename from pkg/grpc/llm/transformers/falcon.go rename to pkg/backend/llm/transformers/falcon.go index fceb10c4..5299fb02 100644 --- a/pkg/grpc/llm/transformers/falcon.go +++ b/pkg/backend/llm/transformers/falcon.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Falcon struct { - base.Base + base.SingleThread falcon *transformers.Falcon } func (llm *Falcon) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("transformers-falcon backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewFalcon(opts.ModelFile) llm.falcon = model return err } func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -47,7 +37,6 @@ func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) e } results <- res close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/llm/transformers/gpt2.go b/pkg/backend/llm/transformers/gpt2.go similarity index 75% rename from pkg/grpc/llm/transformers/gpt2.go rename to pkg/backend/llm/transformers/gpt2.go index 53b364e9..ab162a76 100644 --- a/pkg/grpc/llm/transformers/gpt2.go +++ b/pkg/backend/llm/transformers/gpt2.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type GPT2 struct { - base.Base + base.SingleThread gpt2 *transformers.GPT2 } func (llm *GPT2) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("gpt2 backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.New(opts.ModelFile) llm.gpt2 = model return err } func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -47,7 +37,6 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) err } results <- res close(results) - llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/gptj.go b/pkg/backend/llm/transformers/gptj.go similarity index 75% rename from pkg/grpc/llm/transformers/gptj.go rename to pkg/backend/llm/transformers/gptj.go index c798c3df..f00f1044 100644 --- a/pkg/grpc/llm/transformers/gptj.go +++ b/pkg/backend/llm/transformers/gptj.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type GPTJ struct { - base.Base + base.SingleThread gptj *transformers.GPTJ } func (llm *GPTJ) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("gptj backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewGPTJ(opts.ModelFile) llm.gptj = model return err } func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -47,7 +37,6 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) err } results <- res close(results) - llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/gptneox.go b/pkg/backend/llm/transformers/gptneox.go similarity index 75% rename from pkg/grpc/llm/transformers/gptneox.go rename to pkg/backend/llm/transformers/gptneox.go index bcaa8da6..a06d910e 100644 --- a/pkg/grpc/llm/transformers/gptneox.go +++ b/pkg/backend/llm/transformers/gptneox.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type GPTNeoX struct { - base.Base + base.SingleThread gptneox *transformers.GPTNeoX } func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("gptneox backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewGPTNeoX(opts.ModelFile) llm.gptneox = model return err } func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -47,7 +37,6 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) } results <- res close(results) - llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/mpt.go b/pkg/backend/llm/transformers/mpt.go similarity index 75% rename from pkg/grpc/llm/transformers/mpt.go rename to pkg/backend/llm/transformers/mpt.go index 1b9272ee..f6e0a143 100644 --- a/pkg/grpc/llm/transformers/mpt.go +++ b/pkg/backend/llm/transformers/mpt.go @@ -7,39 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type MPT struct { - base.Base + base.SingleThread mpt *transformers.MPT } func (llm *MPT) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("mpt backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewMPT(opts.ModelFile) llm.mpt = model return err } func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() - return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -48,7 +37,6 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) erro } results <- res close(results) - llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/predict.go b/pkg/backend/llm/transformers/predict.go similarity index 100% rename from pkg/grpc/llm/transformers/predict.go rename to pkg/backend/llm/transformers/predict.go diff --git a/pkg/grpc/llm/transformers/replit.go b/pkg/backend/llm/transformers/replit.go similarity index 75% rename from pkg/grpc/llm/transformers/replit.go rename to pkg/backend/llm/transformers/replit.go index 0c1fc066..a979edcb 100644 --- a/pkg/grpc/llm/transformers/replit.go +++ b/pkg/backend/llm/transformers/replit.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Replit struct { - base.Base + base.SingleThread replit *transformers.Replit } func (llm *Replit) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("replit backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewReplit(opts.ModelFile) llm.replit = model return err } func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) 
@@ -47,7 +37,6 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) e } results <- res close(results) - llm.Base.Unlock() }() return nil } diff --git a/pkg/grpc/llm/transformers/starcoder.go b/pkg/backend/llm/transformers/starcoder.go similarity index 75% rename from pkg/grpc/llm/transformers/starcoder.go rename to pkg/backend/llm/transformers/starcoder.go index c63256f9..25a758a0 100644 --- a/pkg/grpc/llm/transformers/starcoder.go +++ b/pkg/backend/llm/transformers/starcoder.go @@ -7,38 +7,28 @@ import ( "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Starcoder struct { - base.Base + base.SingleThread starcoder *transformers.Starcoder } func (llm *Starcoder) Load(opts *pb.ModelOptions) error { - if llm.Base.State != pb.StatusResponse_UNINITIALIZED { - log.Warn().Msgf("starcoder backend loading %s while already in state %s!", opts.Model, llm.Base.State.String()) - } - - llm.Base.Lock() - defer llm.Base.Unlock() model, err := transformers.NewStarcoder(opts.ModelFile) llm.starcoder = model return err } func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { - llm.Base.Lock() - defer llm.Base.Unlock() return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error { - llm.Base.Lock() go func() { res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -47,7 +37,6 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string } results <- res close(results) - llm.Base.Unlock() }() return nil diff --git a/pkg/grpc/whisper/whisper.go b/pkg/backend/transcribe/transcript.go similarity index 87% rename from pkg/grpc/whisper/whisper.go rename to pkg/backend/transcribe/transcript.go index 806e1452..e428b324 100644 --- a/pkg/grpc/whisper/whisper.go +++ b/pkg/backend/transcribe/transcript.go @@ -1,4 +1,4 @@ -package whisper +package transcribe import ( "fmt" @@ -7,8 +7,8 @@ import ( "path/filepath" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - wav "github.com/go-audio/wav" - "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" + "github.com/go-audio/wav" + "github.com/go-skynet/LocalAI/api/schema" ) func sh(c string) (string, error) { @@ -29,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (api.Result, error) { - res := api.Result{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { + res := schema.Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { @@ -90,7 +90,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( tokens = append(tokens, t.Id) } - segment := api.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} + segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} res.Segments = append(res.Segments, segment) res.Text += s.Text diff --git a/pkg/grpc/transcribe/whisper.go b/pkg/backend/transcribe/whisper.go similarity index 74% rename from pkg/grpc/transcribe/whisper.go rename to pkg/backend/transcribe/whisper.go index 3d25b050..493d8229 100644 --- a/pkg/grpc/transcribe/whisper.go +++ b/pkg/backend/transcribe/whisper.go @@ -4,14 +4,13 @@ package transcribe // It is meant to be used 
by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - whisperutil "github.com/go-skynet/LocalAI/pkg/grpc/whisper" - "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" ) type Whisper struct { - base.Base + base.SingleThread whisper whisper.Model } @@ -22,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error { return err } -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (api.Result, error) { - return whisperutil.Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { + return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) } diff --git a/pkg/grpc/tts/piper.go b/pkg/backend/tts/piper.go similarity index 98% rename from pkg/grpc/tts/piper.go rename to pkg/backend/tts/piper.go index e5e99fc8..41f10049 100644 --- a/pkg/grpc/tts/piper.go +++ b/pkg/backend/tts/piper.go @@ -13,7 +13,7 @@ import ( ) type Piper struct { - base.Base + base.SingleThread piper *PiperB } diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index ffce63c7..739d1cbb 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -5,34 +5,32 @@ package base import ( "fmt" "os" - "sync" + "github.com/go-skynet/LocalAI/api/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" gopsutil "github.com/shirou/gopsutil/v3/process" ) +// Base is a base class for all backends to implement. +// Note: backends that do not support multiple requests +// should use SingleThread instead. type Base struct { - backendBusy sync.Mutex - State pb.StatusResponse_State } -func (llm *Base) Busy() bool { - r := llm.backendBusy.TryLock() - if r { - llm.backendBusy.Unlock() - } - return r +func (llm *Base) Locking() bool { + return false } func (llm *Base) Lock() { - llm.backendBusy.Lock() - llm.State = pb.StatusResponse_BUSY + panic("not implemented") } func (llm *Base) Unlock() { - llm.State = pb.StatusResponse_READY - llm.backendBusy.Unlock() + panic("not implemented") +} + +func (llm *Base) Busy() bool { + return false } func (llm *Base) Load(opts *pb.ModelOptions) error { @@ -55,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) { - return api.Result{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { + return schema.Result{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { @@ -69,7 +67,12 @@ func (llm *Base) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationRespons // backends may wish to call this to capture the gopsutil info, then enhance with additional memory usage details?
func (llm *Base) Status() (pb.StatusResponse, error) { + return pb.StatusResponse{ + Memory: memoryUsage(), + }, nil +} +func memoryUsage() *pb.MemoryUsageData { mud := pb.MemoryUsageData{ Breakdown: make(map[string]uint64), } @@ -85,9 +88,5 @@ func (llm *Base) Status() (pb.StatusResponse, error) { mud.Breakdown["gopsutil-RSS"] = memInfo.RSS } } - - return pb.StatusResponse{ - State: llm.State, - Memory: &mud, - }, nil + return &mud } diff --git a/pkg/grpc/base/singlethread.go b/pkg/grpc/base/singlethread.go new file mode 100644 index 00000000..91aa26ee --- /dev/null +++ b/pkg/grpc/base/singlethread.go @@ -0,0 +1,52 @@ +package base + +import ( + "sync" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" +) + +// SingleThread is for backends that do not support multiple requests. +// Only one request will be served at a time. +// This is useful for models that are not thread safe and cannot run +// multiple requests at the same time. +type SingleThread struct { + Base + backendBusy sync.Mutex +} + +// Locking returns true if the backend needs to lock resources +func (llm *SingleThread) Locking() bool { + return true +} + +func (llm *SingleThread) Lock() { + llm.backendBusy.Lock() +} + +func (llm *SingleThread) Unlock() { + llm.backendBusy.Unlock() +} + +func (llm *SingleThread) Busy() bool { + r := llm.backendBusy.TryLock() + if r { + llm.backendBusy.Unlock() + } + return r +} + +// backends may wish to call this to capture the gopsutil info, then enhance with additional memory usage details? +func (llm *SingleThread) Status() (pb.StatusResponse, error) { + mud := memoryUsage() + + state := pb.StatusResponse_READY + if llm.Busy() { + state = pb.StatusResponse_BUSY + } + + return pb.StatusResponse{ + State: state, + Memory: mud, + }, nil +} diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index d69251ff..0697ac69 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -7,8 +7,8 @@ import ( "sync" "time" + "github.com/go-skynet/LocalAI/api/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -158,7 +158,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...)
} -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { c.setBusy(true) defer c.setBusy(false) conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -171,14 +171,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &api.Result{} + tresult := &schema.Result{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { tks = append(tks, int(t)) } tresult.Segments = append(tresult.Segments, - api.Segment{ + schema.Segment{ Text: s.Text, Id: int(s.Id), Start: time.Duration(s.Start), diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 6c46f764..a76261c1 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,18 +1,21 @@ package grpc import ( + "github.com/go-skynet/LocalAI/api/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" ) type LLM interface { Busy() bool + Lock() + Unlock() + Locking() bool Predict(*pb.PredictOptions) (string, error) PredictStream(*pb.PredictOptions, chan string) error Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (api.Result, error) + AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 58ea4e7e..24dbe098 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -30,6 +30,10 @@ func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, e } func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } embeds, err := s.llm.Embeddings(in) if err != nil { return nil, err @@ -39,6 +43,10 @@ func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.Embe } func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } err := s.llm.Load(in) if err != nil { return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err @@ -47,11 +55,19 @@ func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result } func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } result, err := s.llm.Predict(in) return newReply(result), err } func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } err := s.llm.GenerateImage(in) if err != nil { return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err @@ -60,6 +76,10 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) } func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } err := s.llm.TTS(in) if err != nil { return &pb.Result{Message: 
fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err @@ -68,6 +88,10 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) } func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } result, err := s.llm.AudioTranscription(in) if err != nil { return nil, err @@ -93,7 +117,10 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques } func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { - + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } resultChan := make(chan string) done := make(chan bool) @@ -111,6 +138,10 @@ func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictS } func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } res, err := s.llm.TokenizeString(in) if err != nil { return nil, err