From af9e5a2d05d477eedaf1bff08370208d2b4a9d86 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 17 Apr 2024 23:33:49 +0200 Subject: [PATCH] Revert #1963 (#2056) * Revert "fix(fncall): fix regression introduced in #1963 (#2048)" This reverts commit 6b06d4e0af4db7a8aa8e131ec2b3af171934862e. * Revert "fix: action-tmate back to upstream, dead code removal (#2038)" This reverts commit fdec8a9d00a034ccd8e075008edd165147edf328. * Revert "feat(grpc): return consumed token count and update response accordingly (#2035)" This reverts commit e843d7df0e8b177ab122a9f7bfa7196274ccd204. * Revert "refactor: backend/service split, channel-based llm flow (#1963)" This reverts commit eed5706994a3e770a0194cad9d1cfd724ba1b10a. * feat(grpc): return consumed token count and update response accordingly Fixes: #1920 Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- .github/workflows/test.yml | 15 +- Makefile | 18 +- backend/go/transcribe/transcript.go | 6 +- backend/go/transcribe/whisper.go | 2 +- core/backend/embeddings.go | 90 +- core/backend/image.go | 259 +----- core/backend/llm.go | 271 ++---- core/backend/options.go | 84 +- core/backend/transcript.go | 41 +- core/backend/tts.go | 77 +- core/cli/run.go | 8 +- core/cli/transcript.go | 19 +- core/cli/tts.go | 26 +- core/config/backend_config.go | 301 ++++++- core/config/backend_config_loader.go | 509 ----------- core/config/exports_test.go | 6 - core/http/api.go | 227 ++--- core/http/api_test.go | 98 +-- core/http/ctx/fiber.go | 65 +- core/http/endpoints/elevenlabs/tts.go | 39 +- .../http/endpoints/localai/backend_monitor.go | 4 +- core/http/endpoints/localai/tts.go | 39 +- core/http/endpoints/openai/assistant.go | 2 +- core/http/endpoints/openai/chat.go | 621 ++++++++++++-- core/http/endpoints/openai/completion.go | 163 +++- core/http/endpoints/openai/edit.go | 78 +- core/http/endpoints/openai/embeddings.go | 65 +- core/http/endpoints/openai/image.go | 216 ++++- core/http/endpoints/openai/inference.go | 55 ++ core/http/endpoints/openai/list.go | 52 +- core/http/endpoints/openai/request.go | 285 ++++++ core/http/endpoints/openai/transcription.go | 28 +- core/schema/{transcription.go => whisper.go} | 2 +- core/services/backend_monitor.go | 30 +- core/services/gallery.go | 116 +-- core/services/list_models.go | 72 -- core/services/openai.go | 808 ------------------ core/startup/startup.go | 91 +- core/state.go | 41 - .../llm text/-completions Stream.bru | 25 - pkg/concurrency/concurrency.go | 135 --- pkg/concurrency/concurrency_test.go | 101 --- pkg/concurrency/types.go | 6 - pkg/grpc/backend.go | 2 +- pkg/grpc/base/base.go | 4 +- pkg/grpc/client.go | 4 +- pkg/grpc/embed.go | 4 +- pkg/grpc/interface.go | 2 +- pkg/model/initializers.go | 8 +- pkg/startup/model_preload.go | 85 ++ .../startup}/model_preload_test.go | 5 +- pkg/utils/base64.go | 50 -- 52 files changed, 2295 insertions(+), 3065 deletions(-) delete mode 100644 core/config/backend_config_loader.go delete mode 100644 core/config/exports_test.go create mode 100644 core/http/endpoints/openai/inference.go create mode 100644 core/http/endpoints/openai/request.go rename core/schema/{transcription.go => whisper.go} (90%) delete mode 100644 core/services/list_models.go delete mode 100644 core/services/openai.go delete mode 100644 core/state.go delete mode 100644 examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru delete mode 100644 pkg/concurrency/concurrency.go delete mode 100644 pkg/concurrency/concurrency_test.go delete mode 100644 
pkg/concurrency/types.go create mode 100644 pkg/startup/model_preload.go rename {core/services => pkg/startup}/model_preload_test.go (96%) delete mode 100644 pkg/utils/base64.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 156294b5..46c4e065 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -121,9 +121,8 @@ jobs: PATH="$PATH:/root/go/bin" GO_TAGS="stablediffusion tts" make --jobs 5 --output-sync=target test - name: Setup tmate session if tests fail if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3.18 - with: - connect-timeout-seconds: 180 + uses: mxschmitt/action-tmate@v3 + timeout-minutes: 5 tests-aio-container: runs-on: ubuntu-latest @@ -174,9 +173,8 @@ jobs: make run-e2e-aio - name: Setup tmate session if tests fail if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3.18 - with: - connect-timeout-seconds: 180 + uses: mxschmitt/action-tmate@v3 + timeout-minutes: 5 tests-apple: runs-on: macOS-14 @@ -209,6 +207,5 @@ jobs: BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make --jobs 4 --output-sync=target test - name: Setup tmate session if tests fail if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3.18 - with: - connect-timeout-seconds: 180 \ No newline at end of file + uses: mxschmitt/action-tmate@v3 + timeout-minutes: 5 \ No newline at end of file diff --git a/Makefile b/Makefile index fdc7aade..6715e91e 100644 --- a/Makefile +++ b/Makefile @@ -301,9 +301,6 @@ clean-tests: rm -rf test-dir rm -rf core/http/backend-assets -halt-backends: ## Used to clean up stray backends sometimes left running when debugging manually - ps | grep 'backend-assets/grpc/' | awk '{print $$1}' | xargs -I {} kill -9 {} - ## Build: build: prepare backend-assets grpcs ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) @@ -368,13 +365,13 @@ run-e2e-image: run-e2e-aio: @echo 'Running e2e AIO tests' - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e-aio test-e2e: @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390/v1 \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 -v -r ./tests/e2e teardown-e2e: rm -rf $(TEST_DIR) || true @@ -382,15 +379,15 @@ teardown-e2e: test-gpt4all: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r $(TEST_PATHS) test-llama: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r $(TEST_PATHS) test-llama-gguf: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models 
\ - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts 5 -v -r $(TEST_PATHS) test-tts: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ @@ -648,10 +645,7 @@ backend-assets/grpc/llama-ggml: sources/go-llama-ggml sources/go-llama-ggml/libb $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(CURDIR)/sources/go-llama-ggml CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/sources/go-llama-ggml LIBRARY_PATH=$(CURDIR)/sources/go-llama-ggml \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama-ggml ./backend/go/llm/llama-ggml/ -# EXPERIMENTAL: -ifeq ($(BUILD_TYPE),metal) - cp $(CURDIR)/sources/go-llama-ggml/llama.cpp/ggml-metal.metal backend-assets/grpc/ -endif + backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/ diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index b38d5b9f..fdfaa974 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -21,7 +21,7 @@ func runCommand(command []string) (string, error) { // AudioToWav converts audio to wav for transcribe. // TODO: use https://github.com/mccoyst/ogg? func audioToWav(src, dst string) error { - command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} + command := []string{"ffmpeg", "-i", src, "-format", "s16le", "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le", dst} out, err := runCommand(command) if err != nil { return fmt.Errorf("error: %w out: %s", err, out) @@ -29,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) { - res := schema.TranscriptionResult{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { + res := schema.Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index a9a62d24..ac93be01 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error { return err } -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) } diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 2c63dedc..03ff90b9 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -2,100 +2,14 @@ package backend import ( "fmt" - "time" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/google/uuid" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc" - "github.com/go-skynet/LocalAI/pkg/model" + model 
"github.com/go-skynet/LocalAI/pkg/model" ) -type EmbeddingsBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig -} - -func NewEmbeddingsBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *EmbeddingsBackendService { - return &EmbeddingsBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (ebs *EmbeddingsBackendService) Embeddings(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] { - - resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - go func(request *schema.OpenAIRequest) { - if request.Model == "" { - request.Model = model.StableDiffusionBackend - } - - bc, request, err := ebs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, ebs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - items := []schema.Item{} - - for i, s := range bc.InputToken { - // get the model function to call for the result - embedFn, err := modelEmbedding("", s, ebs.ml, bc, ebs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - embeddings, err := embedFn() - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range bc.InputStrings { - // get the model function to call for the result - embedFn, err := modelEmbedding(s, []int{}, ebs.ml, bc, ebs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - embeddings, err := embedFn() - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Data: items, - Object: "list", - } - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp} - close(resultChannel) - }(request) - return resultChannel -} - -func modelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { modelFile := backendConfig.Model grpcOpts := gRPCModelOpts(backendConfig) diff --git a/core/backend/image.go b/core/backend/image.go index affb3bb3..b0cffb0b 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -1,252 +1,18 @@ package backend import ( - "bufio" - "encoding/base64" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/google/uuid" - "github.com/rs/zerolog/log" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" ) -type ImageGenerationBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig - BaseUrlForGeneratedImages string -} - -func NewImageGenerationBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ImageGenerationBackendService { - return &ImageGenerationBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (igbs *ImageGenerationBackendService) GenerateImage(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.OpenAIResponse] { - resultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - go func(request *schema.OpenAIRequest) { - bc, request, err := igbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, igbs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - src := "" - if request.File != "" { - - var fileData []byte - // check if input.File is an URL, if so download it and save it - // to a temporary file - if strings.HasPrefix(request.File, "http://") || strings.HasPrefix(request.File, "https://") { - out, err := downloadFile(request.File) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed downloading file:%w", err)} - close(resultChannel) - return - } - defer os.RemoveAll(out) - - fileData, err = os.ReadFile(out) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("failed reading file:%w", err)} - close(resultChannel) - return - } - - } else { - // base 64 decode the file and write it somewhere - // that we will cleanup - fileData, err = base64.StdEncoding.DecodeString(request.File) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - } - - // Create a temporary file - outputFile, err := os.CreateTemp(igbs.appConfig.ImageDir, "b64") - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - // write the base64 result - writer := bufio.NewWriter(outputFile) - _, err = writer.Write(fileData) - if err != nil { - outputFile.Close() - 
resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - outputFile.Close() - src = outputFile.Name() - defer os.RemoveAll(src) - } - - log.Debug().Msgf("Parameter Config: %+v", bc) - - switch bc.Backend { - case "stablediffusion": - bc.Backend = model.StableDiffusionBackend - case "tinydream": - bc.Backend = model.TinyDreamBackend - case "": - bc.Backend = model.StableDiffusionBackend - if bc.Model == "" { - bc.Model = "stablediffusion_assets" // TODO: check? - } - } - - sizeParts := strings.Split(request.Size, "x") - if len(sizeParts) != 2 { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} - close(resultChannel) - return - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} - close(resultChannel) - return - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: fmt.Errorf("invalid value for 'size'")} - close(resultChannel) - return - } - - b64JSON := false - if request.ResponseFormat.Type == "b64_json" { - b64JSON = true - } - // src and clip_skip - var result []schema.Item - for _, i := range bc.PromptStrings { - n := request.N - if request.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := bc.Step - if step == 0 { - step = 15 - } - - if request.Mode != 0 { - mode = request.Mode - } - - if request.Step != 0 { - step = request.Step - } - - tempDir := "" - if !b64JSON { - tempDir = igbs.appConfig.ImageDir - } - // Create a temporary file - outputFile, err := os.CreateTemp(tempDir, "b64") - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - if request.Seed == nil { - zVal := 0 // Idiomatic way to do this? Actually needed? 
- request.Seed = &zVal - } - - fn, err := imageGeneration(height, width, mode, step, *request.Seed, positive_prompt, negative_prompt, src, output, igbs.ml, bc, igbs.appConfig) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - if err := fn(); err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - - item := &schema.Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: err} - close(resultChannel) - return - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = igbs.BaseUrlForGeneratedImages + base - } - - result = append(result, *item) - } - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Data: result, - } - resultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: resp} - close(resultChannel) - }(request) - return resultChannel -} - -func imageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig *config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { threads := backendConfig.Threads if *threads == 0 && appConfig.Threads != 0 { threads = &appConfig.Threads } - gRPCOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(backendConfig.Backend), model.WithAssetDir(appConfig.AssetsDestination), @@ -284,24 +50,3 @@ func imageGeneration(height, width, mode, step, seed int, positive_prompt, negat return fn, nil } - -// TODO: Replace this function with pkg/downloader - no reason to have a (crappier) bespoke download file fn here, but get things working before that change. -func downloadFile(url string) (string, error) { - // Get the data - resp, err := http.Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // Create the file - out, err := os.CreateTemp("", "image") - if err != nil { - return "", err - } - defer out.Close() - - // Write the body to file - _, err = io.Copy(out, resp.Body) - return out.Name(), err -} diff --git a/core/backend/llm.go b/core/backend/llm.go index 75766d78..a4d1e5f3 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -11,22 +11,17 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/rs/zerolog/log" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) -type LLMRequest struct { - Id int // TODO Remove if not used. - Text string - Images []string - RawMessages []schema.Message - // TODO: Other Modalities? +type LLMResponse struct { + Response string // should this be []byte? 
+ Usage TokenUsage } type TokenUsage struct { @@ -34,94 +29,57 @@ type TokenUsage struct { Completion int } -type LLMResponse struct { - Request *LLMRequest - Response string // should this be []byte? - Usage TokenUsage -} - -// TODO: Does this belong here or in core/services/openai.go? -type LLMResponseBundle struct { - Request *schema.OpenAIRequest - Response []schema.Choice - Usage TokenUsage -} - -type LLMBackendService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig - ftMutex sync.Mutex - cutstrings map[string]*regexp.Regexp -} - -func NewLLMBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *LLMBackendService { - return &LLMBackendService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - ftMutex: sync.Mutex{}, - cutstrings: make(map[string]*regexp.Regexp), +func ModelInference(ctx context.Context, s string, messages []schema.Message, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { + modelFile := c.Model + threads := c.Threads + if *threads == 0 && o.Threads != 0 { + threads = &o.Threads } -} - -// TODO: Should ctx param be removed and replaced with hardcoded req.Context? -func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, bc *config.BackendConfig, enableTokenChannel bool) ( - resultChannel <-chan concurrency.ErrorOr[*LLMResponse], tokenChannel <-chan concurrency.ErrorOr[*LLMResponse], err error) { - - threads := bc.Threads - if (threads == nil || *threads == 0) && llmbs.appConfig.Threads != 0 { - threads = &llmbs.appConfig.Threads - } - - grpcOpts := gRPCModelOpts(bc) + grpcOpts := gRPCModelOpts(c) var inferenceModel grpc.Backend + var err error - opts := modelOpts(bc, llmbs.appConfig, []model.Option{ + opts := modelOpts(c, o, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup - model.WithAssetDir(llmbs.appConfig.AssetsDestination), - model.WithModel(bc.Model), - model.WithContext(llmbs.appConfig.Context), + model.WithAssetDir(o.AssetsDestination), + model.WithModel(modelFile), + model.WithContext(o.Context), }) - if bc.Backend != "" { - opts = append(opts, model.WithBackendString(bc.Backend)) + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) } - // Check if bc.Model exists, if it doesn't try to load it from the gallery - if llmbs.appConfig.AutoloadGalleries { // experimental - if _, err := os.Stat(bc.Model); os.IsNotExist(err) { + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGalleryByName(llmbs.appConfig.Galleries, bc.Model, llmbs.appConfig.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) if err != nil { - return nil, nil, err + return nil, err } } } - if bc.Backend == "" { - log.Debug().Msgf("backend not known for %q, falling back to greedy loader to find it", bc.Model) - inferenceModel, err = llmbs.ml.GreedyLoader(opts...) 
+ if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) } else { - inferenceModel, err = llmbs.ml.BackendLoader(opts...) + inferenceModel, err = loader.BackendLoader(opts...) } if err != nil { - log.Error().Err(err).Msg("[llmbs.Inference] failed to load a backend") - return + return nil, err } - grpcPredOpts := gRPCPredictOpts(bc, llmbs.appConfig.ModelPath) - grpcPredOpts.Prompt = req.Text - grpcPredOpts.Images = req.Images - - if bc.TemplateConfig.UseTokenizerTemplate && req.Text == "" { - grpcPredOpts.UseTokenizerTemplate = true - protoMessages := make([]*proto.Message, len(req.RawMessages), len(req.RawMessages)) - for i, message := range req.RawMessages { + var protoMessages []*proto.Message + // if we are using the tokenizer template, we need to convert the messages to proto messages + // unless the prompt has already been tokenized (non-chat endpoints + functions) + if c.TemplateConfig.UseTokenizerTemplate && s == "" { + protoMessages = make([]*proto.Message, len(messages), len(messages)) + for i, message := range messages { protoMessages[i] = &proto.Message{ Role: message.Role, } @@ -129,32 +87,47 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, case string: protoMessages[i].Content = ct default: - err = fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) - return + return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct) } } } - tokenUsage := TokenUsage{} + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + fn := func() (LLMResponse, error) { + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + opts.Messages = protoMessages + opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate + opts.Images = images - promptInfo, pErr := inferenceModel.TokenizeString(ctx, grpcPredOpts) - if pErr == nil && promptInfo.Length > 0 { - tokenUsage.Prompt = int(promptInfo.Length) - } + tokenUsage := TokenUsage{} - rawResultChannel := make(chan concurrency.ErrorOr[*LLMResponse]) - // TODO this next line is the biggest argument for taking named return values _back_ out!!! - var rawTokenChannel chan concurrency.ErrorOr[*LLMResponse] + // check the per-model feature flag for usage, since tokenCallback may have a cost. + // Defaults to off as for now it is still experimental + if c.FeatureFlag.Enabled("usage") { + userTokenCallback := tokenCallback + if userTokenCallback == nil { + userTokenCallback = func(token string, usage TokenUsage) bool { + return true + } + } - if enableTokenChannel { - rawTokenChannel = make(chan concurrency.ErrorOr[*LLMResponse]) + promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) + if pErr == nil && promptInfo.Length > 0 { + tokenUsage.Prompt = int(promptInfo.Length) + } - // TODO Needs better name - ss := "" + tokenCallback = func(token string, usage TokenUsage) bool { + tokenUsage.Completion++ + return userTokenCallback(token, tokenUsage) + } + } + + if tokenCallback != nil { + ss := "" - go func() { var partialRune []byte - err := inferenceModel.PredictStream(ctx, grpcPredOpts, func(chars []byte) { + err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { partialRune = append(partialRune, chars...) 
for len(partialRune) > 0 { @@ -164,126 +137,54 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest, break } - tokenUsage.Completion++ - rawTokenChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ - Response: string(r), - Usage: tokenUsage, - }} - + tokenCallback(string(r), tokenUsage) ss += string(r) partialRune = partialRune[size:] } }) - close(rawTokenChannel) + return LLMResponse{ + Response: ss, + Usage: tokenUsage, + }, err + } else { + // TODO: Is the chicken bit the only way to get here? is that acceptable? + reply, err := inferenceModel.Predict(ctx, opts) if err != nil { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err} - } else { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ - Response: ss, - Usage: tokenUsage, - }} + return LLMResponse{}, err } - close(rawResultChannel) - }() - } else { - go func() { - reply, err := inferenceModel.Predict(ctx, grpcPredOpts) if tokenUsage.Prompt == 0 { tokenUsage.Prompt = int(reply.PromptTokens) } if tokenUsage.Completion == 0 { tokenUsage.Completion = int(reply.Tokens) } - if err != nil { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err} - close(rawResultChannel) - } else { - rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Value: &LLMResponse{ - Response: string(reply.Message), - Usage: tokenUsage, - }} - close(rawResultChannel) - } - }() + return LLMResponse{ + Response: string(reply.Message), + Usage: tokenUsage, + }, err + } } - resultChannel = rawResultChannel - tokenChannel = rawTokenChannel - return + return fn, nil } -// TODO: Should predInput be a seperate param still, or should this fn handle extracting it from request?? -func (llmbs *LLMBackendService) GenerateText(predInput string, request *schema.OpenAIRequest, bc *config.BackendConfig, - mappingFn func(*LLMResponse) schema.Choice, enableCompletionChannels bool, enableTokenChannels bool) ( - // Returns: - resultChannel <-chan concurrency.ErrorOr[*LLMResponseBundle], completionChannels []<-chan concurrency.ErrorOr[*LLMResponse], tokenChannels []<-chan concurrency.ErrorOr[*LLMResponse], err error) { +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} - rawChannel := make(chan concurrency.ErrorOr[*LLMResponseBundle]) - resultChannel = rawChannel - - if request.N == 0 { // number of completions to return - request.N = 1 - } - images := []string{} - for _, m := range request.Messages { - images = append(images, m.StringImages...) - } - - for i := 0; i < request.N; i++ { - - individualResultChannel, tokenChannel, infErr := llmbs.Inference(request.Context, &LLMRequest{ - Text: predInput, - Images: images, - RawMessages: request.Messages, - }, bc, enableTokenChannels) - if infErr != nil { - err = infErr // Avoids complaints about redeclaring err but looks dumb - return - } - completionChannels = append(completionChannels, individualResultChannel) - tokenChannels = append(tokenChannels, tokenChannel) - } - - go func() { - initialBundle := LLMResponseBundle{ - Request: request, - Response: []schema.Choice{}, - Usage: TokenUsage{}, - } - - wg := concurrency.SliceOfChannelsReducer(completionChannels, rawChannel, func(iv concurrency.ErrorOr[*LLMResponse], ov concurrency.ErrorOr[*LLMResponseBundle]) concurrency.ErrorOr[*LLMResponseBundle] { - if iv.Error != nil { - ov.Error = iv.Error - // TODO: Decide if we should wipe partials or not? 
- return ov - } - ov.Value.Usage.Prompt += iv.Value.Usage.Prompt - ov.Value.Usage.Completion += iv.Value.Usage.Completion - - ov.Value.Response = append(ov.Value.Response, mappingFn(iv.Value)) - return ov - }, concurrency.ErrorOr[*LLMResponseBundle]{Value: &initialBundle}, true) - wg.Wait() - - }() - - return -} - -func (llmbs *LLMBackendService) Finetune(config config.BackendConfig, input, prediction string) string { +func Finetune(config config.BackendConfig, input, prediction string) string { if config.Echo { prediction = input + prediction } for _, c := range config.Cutstrings { - llmbs.ftMutex.Lock() - reg, ok := llmbs.cutstrings[c] + mu.Lock() + reg, ok := cutstrings[c] if !ok { - llmbs.cutstrings[c] = regexp.MustCompile(c) - reg = llmbs.cutstrings[c] + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] } - llmbs.ftMutex.Unlock() + mu.Unlock() prediction = reg.ReplaceAllString(prediction, "") } diff --git a/core/backend/options.go b/core/backend/options.go index 0b4e56db..5b303b05 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -10,7 +10,7 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { +func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { if so.SingleBackend { opts = append(opts, model.WithSingleActiveBackend()) } @@ -19,12 +19,12 @@ func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []mo opts = append(opts, model.EnableParallelRequests) } - if bc.GRPC.Attempts != 0 { - opts = append(opts, model.WithGRPCAttempts(bc.GRPC.Attempts)) + if c.GRPC.Attempts != 0 { + opts = append(opts, model.WithGRPCAttempts(c.GRPC.Attempts)) } - if bc.GRPC.AttemptsSleepTime != 0 { - opts = append(opts, model.WithGRPCAttemptsDelay(bc.GRPC.AttemptsSleepTime)) + if c.GRPC.AttemptsSleepTime != 0 { + opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) } for k, v := range so.ExternalGRPCBackends { @@ -34,7 +34,7 @@ func modelOpts(bc *config.BackendConfig, so *config.ApplicationConfig, opts []mo return opts } -func getSeed(c *config.BackendConfig) int32 { +func getSeed(c config.BackendConfig) int32 { seed := int32(*c.Seed) if seed == config.RAND_SEED { seed = rand.Int31() @@ -43,7 +43,7 @@ func getSeed(c *config.BackendConfig) int32 { return seed } -func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions { +func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -104,47 +104,47 @@ func gRPCModelOpts(c *config.BackendConfig) *pb.ModelOptions { } } -func gRPCPredictOpts(bc *config.BackendConfig, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions { promptCachePath := "" - if bc.PromptCachePath != "" { - p := filepath.Join(modelPath, bc.PromptCachePath) + if c.PromptCachePath != "" { + p := filepath.Join(modelPath, c.PromptCachePath) os.MkdirAll(filepath.Dir(p), 0755) promptCachePath = p } return &pb.PredictOptions{ - Temperature: float32(*bc.Temperature), - TopP: float32(*bc.TopP), - NDraft: bc.NDraft, - TopK: int32(*bc.TopK), - Tokens: int32(*bc.Maxtokens), - Threads: int32(*bc.Threads), - PromptCacheAll: bc.PromptCacheAll, - PromptCacheRO: bc.PromptCacheRO, + Temperature: float32(*c.Temperature), + TopP: float32(*c.TopP), + NDraft: c.NDraft, + TopK: int32(*c.TopK), + Tokens: int32(*c.Maxtokens), + Threads: int32(*c.Threads), + 
PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, PromptCachePath: promptCachePath, - F16KV: *bc.F16, - DebugMode: *bc.Debug, - Grammar: bc.Grammar, - NegativePromptScale: bc.NegativePromptScale, - RopeFreqBase: bc.RopeFreqBase, - RopeFreqScale: bc.RopeFreqScale, - NegativePrompt: bc.NegativePrompt, - Mirostat: int32(*bc.LLMConfig.Mirostat), - MirostatETA: float32(*bc.LLMConfig.MirostatETA), - MirostatTAU: float32(*bc.LLMConfig.MirostatTAU), - Debug: *bc.Debug, - StopPrompts: bc.StopWords, - Repeat: int32(bc.RepeatPenalty), - NKeep: int32(bc.Keep), - Batch: int32(bc.Batch), - IgnoreEOS: bc.IgnoreEOS, - Seed: getSeed(bc), - FrequencyPenalty: float32(bc.FrequencyPenalty), - MLock: *bc.MMlock, - MMap: *bc.MMap, - MainGPU: bc.MainGPU, - TensorSplit: bc.TensorSplit, - TailFreeSamplingZ: float32(*bc.TFZ), - TypicalP: float32(*bc.TypicalP), + F16KV: *c.F16, + DebugMode: *c.Debug, + Grammar: c.Grammar, + NegativePromptScale: c.NegativePromptScale, + RopeFreqBase: c.RopeFreqBase, + RopeFreqScale: c.RopeFreqScale, + NegativePrompt: c.NegativePrompt, + Mirostat: int32(*c.LLMConfig.Mirostat), + MirostatETA: float32(*c.LLMConfig.MirostatETA), + MirostatTAU: float32(*c.LLMConfig.MirostatTAU), + Debug: *c.Debug, + StopPrompts: c.StopWords, + Repeat: int32(c.RepeatPenalty), + NKeep: int32(c.Keep), + Batch: int32(c.Batch), + IgnoreEOS: c.IgnoreEOS, + Seed: getSeed(c), + FrequencyPenalty: float32(c.FrequencyPenalty), + MLock: *c.MMlock, + MMap: *c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + TailFreeSamplingZ: float32(*c.TFZ), + TypicalP: float32(*c.TypicalP), } } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 6761c2ac..4c3859df 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -7,48 +7,11 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" ) -type TranscriptionBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig -} - -func NewTranscriptionBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TranscriptionBackendService { - return &TranscriptionBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, - } -} - -func (tbs *TranscriptionBackendService) Transcribe(request *schema.OpenAIRequest) <-chan concurrency.ErrorOr[*schema.TranscriptionResult] { - responseChannel := make(chan concurrency.ErrorOr[*schema.TranscriptionResult]) - go func(request *schema.OpenAIRequest) { - bc, request, err := tbs.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, tbs.appConfig) - if err != nil { - responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: fmt.Errorf("failed reading parameters from request:%w", err)} - close(responseChannel) - return - } - - tr, err := modelTranscription(request.File, request.Language, tbs.ml, bc, tbs.appConfig) - if err != nil { - responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Error: err} - close(responseChannel) - return - } - responseChannel <- concurrency.ErrorOr[*schema.TranscriptionResult]{Value: tr} - close(responseChannel) - }(request) - return responseChannel -} - -func modelTranscription(audio, language string, ml *model.ModelLoader, backendConfig *config.BackendConfig, appConfig 
*config.ApplicationConfig) (*schema.TranscriptionResult, error) { +func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), diff --git a/core/backend/tts.go b/core/backend/tts.go index d1fa270d..f97b6202 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -7,60 +7,29 @@ import ( "path/filepath" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/concurrency" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/model" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) -type TextToSpeechBackendService struct { - ml *model.ModelLoader - bcl *config.BackendConfigLoader - appConfig *config.ApplicationConfig -} +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext -func NewTextToSpeechBackendService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *TextToSpeechBackendService { - return &TextToSpeechBackendService{ - ml: ml, - bcl: bcl, - appConfig: appConfig, + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) } } -func (ttsbs *TextToSpeechBackendService) TextToAudioFile(request *schema.TTSRequest) <-chan concurrency.ErrorOr[*string] { - responseChannel := make(chan concurrency.ErrorOr[*string]) - go func(request *schema.TTSRequest) { - cfg, err := ttsbs.bcl.LoadBackendConfigFileByName(request.Model, ttsbs.appConfig.ModelPath, - config.LoadOptionDebug(ttsbs.appConfig.Debug), - config.LoadOptionThreads(ttsbs.appConfig.Threads), - config.LoadOptionContextSize(ttsbs.appConfig.ContextSize), - config.LoadOptionF16(ttsbs.appConfig.F16), - ) - if err != nil { - responseChannel <- concurrency.ErrorOr[*string]{Error: err} - close(responseChannel) - return - } - - if request.Backend != "" { - cfg.Backend = request.Backend - } - - outFile, _, err := modelTTS(cfg.Backend, request.Input, cfg.Model, request.Voice, ttsbs.ml, ttsbs.appConfig, cfg) - if err != nil { - responseChannel <- concurrency.ErrorOr[*string]{Error: err} - close(responseChannel) - return - } - responseChannel <- concurrency.ErrorOr[*string]{Value: &outFile} - close(responseChannel) - }(request) - return responseChannel -} - -func modelTTS(backend, text, modelFile string, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig *config.BackendConfig) (string, *proto.Result, error) { +func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend @@ -68,7 +37,7 @@ func modelTTS(backend, text, modelFile string, voice string, loader *model.Model grpcOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(&config.BackendConfig{}, appConfig, []model.Option{ + opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(appConfig.Context), @@ -118,19 +87,3 @@ func modelTTS(backend, text, modelFile string, voice string, loader *model.Model return 
filePath, res, err } - -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} diff --git a/core/cli/run.go b/core/cli/run.go index cafc0b54..0f3ba2de 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -124,11 +124,11 @@ func (r *RunCMD) Run(ctx *Context) error { } if r.PreloadBackendOnly { - _, err := startup.Startup(opts...) + _, _, _, err := startup.Startup(opts...) return err } - application, err := startup.Startup(opts...) + cl, ml, options, err := startup.Startup(opts...) if err != nil { return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) @@ -137,7 +137,7 @@ func (r *RunCMD) Run(ctx *Context) error { // Watch the configuration directory // If the directory does not exist, we don't watch it if _, err := os.Stat(r.LocalaiConfigDir); err == nil { - closeConfigWatcherFn, err := startup.WatchConfigDirectory(r.LocalaiConfigDir, application.ApplicationConfig) + closeConfigWatcherFn, err := startup.WatchConfigDirectory(r.LocalaiConfigDir, options) defer closeConfigWatcherFn() if err != nil { @@ -145,7 +145,7 @@ func (r *RunCMD) Run(ctx *Context) error { } } - appHTTP, err := http.App(application) + appHTTP, err := http.App(cl, ml, options) if err != nil { log.Error().Err(err).Msg("error during HTTP App construction") return err diff --git a/core/cli/transcript.go b/core/cli/transcript.go index f14a1a87..9f36a77c 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -7,7 +7,6 @@ import ( "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" ) @@ -44,21 +43,11 @@ func (t *TranscriptCMD) Run(ctx *Context) error { defer ml.StopAllGRPC() - tbs := backend.NewTranscriptionBackendService(ml, cl, opts) - - resultChannel := tbs.Transcribe(&schema.OpenAIRequest{ - PredictionOptions: schema.PredictionOptions{ - Language: t.Language, - }, - File: t.Filename, - }) - - r := <-resultChannel - - if r.Error != nil { - return r.Error + tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts) + if err != nil { + return err } - for _, segment := range r.Value.Segments { + for _, segment := range tr.Segments { fmt.Println(segment.Start.String(), "-", segment.Text) } return nil diff --git a/core/cli/tts.go b/core/cli/tts.go index c7758c48..1d8fd3a3 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -9,7 +9,6 @@ import ( "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" ) @@ -43,29 +42,20 @@ func (t *TTSCMD) Run(ctx *Context) error { defer ml.StopAllGRPC() - ttsbs := backend.NewTextToSpeechBackendService(ml, config.NewBackendConfigLoader(), opts) + options := config.BackendConfig{} + options.SetDefaults() - request := &schema.TTSRequest{ - Model: t.Model, - Input: text, - Backend: t.Backend, - Voice: t.Voice, - } - - resultsChannel := ttsbs.TextToAudioFile(request) - - rawResult := <-resultsChannel - - if rawResult.Error != nil { - return rawResult.Error + filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options) + if err != nil { + return err } if outputFile != "" { - if err := os.Rename(*rawResult.Value, outputFile); err != 
nil { + if err := os.Rename(filePath, outputFile); err != nil { return err } - fmt.Printf("Generated file %q\n", outputFile) + fmt.Printf("Generate file %s\n", outputFile) } else { - fmt.Printf("Generated file %q\n", *rawResult.Value) + fmt.Printf("Generate file %s\n", filePath) } return nil } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 47e4829d..81c92d01 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -1,7 +1,22 @@ package config import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" + "gopkg.in/yaml.v3" + + "github.com/charmbracelet/glamour" ) const ( @@ -184,7 +199,7 @@ func (c *BackendConfig) FunctionToCall() string { } func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { - lo := &ConfigLoaderOptions{} + lo := &LoadOptions{} lo.Apply(opts...) ctx := lo.ctxSize @@ -297,3 +312,287 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.Debug = &trueV } } + +////// Config Loader //////// + +type BackendConfigLoader struct { + configs map[string]BackendConfig + sync.Mutex +} + +type LoadOptions struct { + debug bool + threads, ctxSize int + f16 bool +} + +func LoadOptionDebug(debug bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.debug = debug + } +} + +func LoadOptionThreads(threads int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.threads = threads + } +} + +func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.ctxSize = ctxSize + } +} + +func LoadOptionF16(f16 bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.f16 = f16 + } +} + +type ConfigLoaderOption func(*LoadOptions) + +func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { + for _, l := range options { + l(lo) + } +} + +// Load a config file for a model +func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + // Load a config file if present after the model name + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, + } + + cfgExisting, exists := cl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } else { + // Try loading a model config file + modelConfig := filepath.Join(modelPath, modelName+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := cl.LoadBackendConfig( + modelConfig, opts..., + ); err != nil { + return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } + } + } + + cfg.SetDefaults(opts...) + + return cfg, nil +} + +func NewBackendConfigLoader() *BackendConfigLoader { + return &BackendConfigLoader{ + configs: make(map[string]BackendConfig), + } +} +func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { + c := &[]*BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + for _, cc := range *c { + cc.SetDefaults(opts...) 
+ } + + return *c, nil +} + +func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + + c := &BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + c.SetDefaults(opts...) + return c, nil +} + +func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadBackendConfigFile(file, opts...) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm.configs[cc.Name] = *cc + } + return nil +} + +func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { + cl.Lock() + defer cl.Unlock() + c, err := ReadBackendConfig(file, opts...) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cl.configs[c.Name] = *c + return nil +} + +func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { + cl.Lock() + defer cl.Unlock() + v, exists := cl.configs[m] + return v, exists +} + +func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { + cl.Lock() + defer cl.Unlock() + var res []BackendConfig + for _, v := range cl.configs { + res = append(res, v) + } + + sort.SliceStable(res, func(i, j int) bool { + return res[i].Name < res[j].Name + }) + + return res +} + +func (cl *BackendConfigLoader) ListBackendConfigs() []string { + cl.Lock() + defer cl.Unlock() + var res []string + for k := range cl.configs { + res = append(res, k) + } + return res +} + +// Preload prepare models if they are not local but url or huggingface repositories +func (cl *BackendConfigLoader) Preload(modelPath string) error { + cl.Lock() + defer cl.Unlock() + + status := func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + } + + log.Info().Msgf("Preloading models from %s", modelPath) + + renderMode := "dark" + if os.Getenv("COLOR") != "" { + renderMode = os.Getenv("COLOR") + } + + glamText := func(t string) { + out, err := glamour.Render(t, renderMode) + if err == nil && os.Getenv("NO_COLOR") == "" { + fmt.Println(out) + } else { + fmt.Println(t) + } + } + + for i, config := range cl.configs { + + // Download files and verify their SHA + for _, file := range config.DownloadFiles { + log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) + + if err := utils.VerifyPath(file.Filename, modelPath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(modelPath, file.Filename) + + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { + return err + } + } + + modelURL := config.PredictionOptions.Model + modelURL = downloader.ConvertURL(modelURL) + + if downloader.LooksLikeURL(modelURL) { + // md5 of model name + md5Name := utils.MD5(modelURL) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) + if err != nil { + return err + } + } + + cc := cl.configs[i] + c := &cc + c.PredictionOptions.Model = md5Name + cl.configs[i] = *c + } + if cl.configs[i].Name != "" { + glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) + } + if 
cl.configs[i].Description != "" { + //glamText("**Description**") + glamText(cl.configs[i].Description) + } + if cl.configs[i].Usage != "" { + //glamText("**Usage**") + glamText(cl.configs[i].Usage) + } + } + return nil +} + +// LoadBackendConfigsFromPath reads all the configurations of the models from a path +// (non-recursive) +func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { + continue + } + c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go deleted file mode 100644 index 62dfc1e0..00000000 --- a/core/config/backend_config_loader.go +++ /dev/null @@ -1,509 +0,0 @@ -package config - -import ( - "encoding/json" - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "sort" - "strings" - "sync" - - "github.com/charmbracelet/glamour" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/downloader" - "github.com/go-skynet/LocalAI/pkg/grammar" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v2" -) - -type BackendConfigLoader struct { - configs map[string]BackendConfig - sync.Mutex -} - -type ConfigLoaderOptions struct { - debug bool - threads, ctxSize int - f16 bool -} - -func LoadOptionDebug(debug bool) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.debug = debug - } -} - -func LoadOptionThreads(threads int) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.threads = threads - } -} - -func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.ctxSize = ctxSize - } -} - -func LoadOptionF16(f16 bool) ConfigLoaderOption { - return func(o *ConfigLoaderOptions) { - o.f16 = f16 - } -} - -type ConfigLoaderOption func(*ConfigLoaderOptions) - -func (lo *ConfigLoaderOptions) Apply(options ...ConfigLoaderOption) { - for _, l := range options { - l(lo) - } -} - -func NewBackendConfigLoader() *BackendConfigLoader { - return &BackendConfigLoader{ - configs: make(map[string]BackendConfig), - } -} - -func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { - bcl.Lock() - defer bcl.Unlock() - c, err := readBackendConfig(file, opts...) 
- if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - bcl.configs[c.Name] = *c - return nil -} - -func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { - bcl.Lock() - defer bcl.Unlock() - v, exists := bcl.configs[m] - return v, exists -} - -func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { - bcl.Lock() - defer bcl.Unlock() - var res []BackendConfig - for _, v := range bcl.configs { - res = append(res, v) - } - sort.SliceStable(res, func(i, j int) bool { - return res[i].Name < res[j].Name - }) - return res -} - -func (bcl *BackendConfigLoader) ListBackendConfigs() []string { - bcl.Lock() - defer bcl.Unlock() - var res []string - for k := range bcl.configs { - res = append(res, k) - } - return res -} - -// Preload prepare models if they are not local but url or huggingface repositories -func (bcl *BackendConfigLoader) Preload(modelPath string) error { - bcl.Lock() - defer bcl.Unlock() - - status := func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - } - - log.Info().Msgf("Preloading models from %s", modelPath) - - renderMode := "dark" - if os.Getenv("COLOR") != "" { - renderMode = os.Getenv("COLOR") - } - - glamText := func(t string) { - out, err := glamour.Render(t, renderMode) - if err == nil && os.Getenv("NO_COLOR") == "" { - fmt.Println(out) - } else { - fmt.Println(t) - } - } - - for i, config := range bcl.configs { - - // Download files and verify their SHA - for _, file := range config.DownloadFiles { - log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) - - if err := utils.VerifyPath(file.Filename, modelPath); err != nil { - return err - } - // Create file path - filePath := filepath.Join(modelPath, file.Filename) - - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { - return err - } - } - - modelURL := config.PredictionOptions.Model - modelURL = downloader.ConvertURL(modelURL) - - if downloader.LooksLikeURL(modelURL) { - // md5 of model name - md5Name := utils.MD5(modelURL) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) - if err != nil { - return err - } - } - - cc := bcl.configs[i] - c := &cc - c.PredictionOptions.Model = md5Name - bcl.configs[i] = *c - } - if bcl.configs[i].Name != "" { - glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name)) - } - if bcl.configs[i].Description != "" { - //glamText("**Description**") - glamText(bcl.configs[i].Description) - } - if bcl.configs[i].Usage != "" { - //glamText("**Usage**") - glamText(bcl.configs[i].Usage) - } - } - return nil -} - -func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { - bcl.Lock() - defer bcl.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { - continue - } - c, err := readBackendConfig(filepath.Join(path, file.Name()), opts...) 
- if err == nil { - bcl.configs[c.Name] = *c - } - } - - return nil -} - -func (bcl *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { - bcl.Lock() - defer bcl.Unlock() - c, err := readBackendConfigFile(file, opts...) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - bcl.configs[cc.Name] = *cc - } - return nil -} - -////////// - -// Load a config file for a model -func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName string, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - - // Load a config file if present after the model name - cfg := &BackendConfig{ - PredictionOptions: schema.PredictionOptions{ - Model: modelName, - }, - } - - cfgExisting, exists := bcl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } else { - // Load a config file if present after the model name - modelConfig := filepath.Join(modelPath, modelName+".yaml") - if _, err := os.Stat(modelConfig); err == nil { - if err := bcl.LoadBackendConfig(modelConfig); err != nil { - return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = bcl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } - } - } - - cfg.SetDefaults(opts...) - return cfg, nil -} - -func readBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { - c := &[]*BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - for _, cc := range *c { - cc.SetDefaults(opts...) - } - - return *c, nil -} - -func readBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - c := &BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - c.SetDefaults(opts...) 
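A hedged usage sketch, outside the patch itself: how a caller resolves a per-model configuration with the functional options (LoadOptionDebug, LoadOptionThreads, LoadOptionContextSize, LoadOptionF16) and LoadBackendConfigFileByName that the restored TTS endpoints further below rely on. The constructor name and the example model and path are assumptions, not taken from the diff.

package main

import (
	"log"

	"github.com/go-skynet/LocalAI/core/config"
)

func main() {
	// Assumption: the revert keeps a NewBackendConfigLoader constructor in core/config.
	cl := config.NewBackendConfigLoader()

	// Option helpers and LoadBackendConfigFileByName match the restored call sites below.
	cfg, err := cl.LoadBackendConfigFileByName("ggml-gpt4all-j", "/models",
		config.LoadOptionDebug(true),
		config.LoadOptionThreads(4),
		config.LoadOptionContextSize(2048),
		config.LoadOptionF16(false),
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("resolved backend=%q model=%q", cfg.Backend, cfg.Model)
}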
- return c, nil -} - -func (bcl *BackendConfigLoader) LoadBackendConfigForModelAndOpenAIRequest(modelFile string, input *schema.OpenAIRequest, appConfig *ApplicationConfig) (*BackendConfig, *schema.OpenAIRequest, error) { - cfg, err := bcl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - LoadOptionContextSize(appConfig.ContextSize), - LoadOptionDebug(appConfig.Debug), - LoadOptionF16(appConfig.F16), - LoadOptionThreads(appConfig.Threads), - ) - - // Set the parameters for the language model prediction - updateBackendConfigFromOpenAIRequest(cfg, input) - - return cfg, input, err -} - -func updateBackendConfigFromOpenAIRequest(bc *BackendConfig, request *schema.OpenAIRequest) { - if request.Echo { - bc.Echo = request.Echo - } - if request.TopK != nil && *request.TopK != 0 { - bc.TopK = request.TopK - } - if request.TopP != nil && *request.TopP != 0 { - bc.TopP = request.TopP - } - - if request.Backend != "" { - bc.Backend = request.Backend - } - - if request.ClipSkip != 0 { - bc.Diffusers.ClipSkip = request.ClipSkip - } - - if request.ModelBaseName != "" { - bc.AutoGPTQ.ModelBaseName = request.ModelBaseName - } - - if request.NegativePromptScale != 0 { - bc.NegativePromptScale = request.NegativePromptScale - } - - if request.UseFastTokenizer { - bc.UseFastTokenizer = request.UseFastTokenizer - } - - if request.NegativePrompt != "" { - bc.NegativePrompt = request.NegativePrompt - } - - if request.RopeFreqBase != 0 { - bc.RopeFreqBase = request.RopeFreqBase - } - - if request.RopeFreqScale != 0 { - bc.RopeFreqScale = request.RopeFreqScale - } - - if request.Grammar != "" { - bc.Grammar = request.Grammar - } - - if request.Temperature != nil && *request.Temperature != 0 { - bc.Temperature = request.Temperature - } - - if request.Maxtokens != nil && *request.Maxtokens != 0 { - bc.Maxtokens = request.Maxtokens - } - - switch stop := request.Stop.(type) { - case string: - if stop != "" { - bc.StopWords = append(bc.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - bc.StopWords = append(bc.StopWords, s) - } - } - } - - if len(request.Tools) > 0 { - for _, tool := range request.Tools { - request.Functions = append(request.Functions, tool.Function) - } - } - - if request.ToolsChoice != nil { - var toolChoice grammar.Tool - switch content := request.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]interface{}: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - request.FunctionCall = map[string]interface{}{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - index := 0 - for i, m := range request.Messages { - switch content := m.Content.(type) { - case string: - request.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - for _, pp := range c { - if pp.Type == "text" { - request.Messages[i].StringContent = pp.Text - } else if pp.Type == "image_url" { - // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: - base64, err := utils.GetImageURLAsBase64(pp.ImageURL.URL) - if err == nil { - request.Messages[i].StringImages = append(request.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - request.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + request.Messages[i].StringContent - index++ - } else { - 
fmt.Print("Failed encoding image", err) - } - } - } - } - } - - if request.RepeatPenalty != 0 { - bc.RepeatPenalty = request.RepeatPenalty - } - - if request.FrequencyPenalty != 0 { - bc.FrequencyPenalty = request.FrequencyPenalty - } - - if request.PresencePenalty != 0 { - bc.PresencePenalty = request.PresencePenalty - } - - if request.Keep != 0 { - bc.Keep = request.Keep - } - - if request.Batch != 0 { - bc.Batch = request.Batch - } - - if request.IgnoreEOS { - bc.IgnoreEOS = request.IgnoreEOS - } - - if request.Seed != nil { - bc.Seed = request.Seed - } - - if request.TypicalP != nil { - bc.TypicalP = request.TypicalP - } - - switch inputs := request.Input.(type) { - case string: - if inputs != "" { - bc.InputStrings = append(bc.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - bc.InputStrings = append(bc.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - bc.InputToken = append(bc.InputToken, tokens) - } - } - } - - // Can be either a string or an object - switch fnc := request.FunctionCall.(type) { - case string: - if fnc != "" { - bc.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - bc.SetFunctionCallNameString(name) - } - - switch p := request.Prompt.(type) { - case string: - bc.PromptStrings = append(bc.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - bc.PromptStrings = append(bc.PromptStrings, s) - } - } - } -} diff --git a/core/config/exports_test.go b/core/config/exports_test.go deleted file mode 100644 index 70ba84e6..00000000 --- a/core/config/exports_test.go +++ /dev/null @@ -1,6 +0,0 @@ -package config - -// This file re-exports private functions to be used directly in unit tests. -// Since this file's name ends in _test.go, theoretically these should not be exposed past the tests. 
- -var ReadBackendConfigFile = readBackendConfigFile diff --git a/core/http/api.go b/core/http/api.go index 7094899a..af38512a 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -1,20 +1,23 @@ package http import ( + "encoding/json" "errors" + "os" "strings" - "github.com/go-skynet/LocalAI/core" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/swagger" // swagger handler "github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs" "github.com/go-skynet/LocalAI/core/http/endpoints/localai" "github.com/go-skynet/LocalAI/core/http/endpoints/openai" + + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" @@ -52,12 +55,13 @@ func readAuthHeader(c *fiber.Ctx) string { // @securityDefinitions.apikey BearerAuth // @in header // @name Authorization -func App(application *core.Application) (*fiber.App, error) { + +func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) { // Return errors as JSON responses app := fiber.New(fiber.Config{ Views: renderEngine(), - BodyLimit: application.ApplicationConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: application.ApplicationConfig.DisableMessage, + BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: appConfig.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -78,7 +82,7 @@ func App(application *core.Application) (*fiber.App, error) { }, }) - if application.ApplicationConfig.Debug { + if appConfig.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) @@ -86,7 +90,7 @@ func App(application *core.Application) (*fiber.App, error) { // Default middleware config - if !application.ApplicationConfig.Debug { + if !appConfig.Debug { app.Use(recover.New()) } @@ -104,7 +108,25 @@ func App(application *core.Application) (*fiber.App, error) { // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { - if len(application.ApplicationConfig.ApiKeys) == 0 { + if len(appConfig.ApiKeys) == 0 { + return c.Next() + } + + // Check for api_keys.json file + fileContent, err := os.ReadFile("api_keys.json") + if err == nil { + // Parse JSON content from the file + var fileKeys []string + err := json.Unmarshal(fileContent, &fileKeys) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) + } + + // Add file keys to options.ApiKeys + appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) 
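A standalone sketch of the api_keys.json handling restored just above: the file is expected to hold a plain JSON array of key strings, which the auth middleware appends to the keys configured at startup. The file name matches the restored code; everything else is illustrative.

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// loadFileKeys reads extra API keys from a JSON array on disk.
// A missing file is not fatal for the middleware above; it simply contributes no keys.
func loadFileKeys(path string) ([]string, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var keys []string
	if err := json.Unmarshal(data, &keys); err != nil {
		return nil, err
	}
	return keys, nil
}

func main() {
	// Example file content: ["sk-local-1","sk-local-2"]
	if err := os.WriteFile("api_keys.json", []byte(`["sk-local-1","sk-local-2"]`), 0600); err != nil {
		fmt.Println(err)
		return
	}
	keys, err := loadFileKeys("api_keys.json")
	fmt.Println(keys, err)
}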
+ } + + if len(appConfig.ApiKeys) == 0 { return c.Next() } @@ -120,7 +142,7 @@ func App(application *core.Application) (*fiber.App, error) { } apiKey := authHeaderParts[1] - for _, key := range application.ApplicationConfig.ApiKeys { + for _, key := range appConfig.ApiKeys { if apiKey == key { return c.Next() } @@ -129,22 +151,20 @@ func App(application *core.Application) (*fiber.App, error) { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) } - if application.ApplicationConfig.CORS { + if appConfig.CORS { var c func(ctx *fiber.Ctx) error - if application.ApplicationConfig.CORSAllowOrigins == "" { + if appConfig.CORSAllowOrigins == "" { c = cors.New() } else { - c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig.CORSAllowOrigins}) + c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins}) } app.Use(c) } - fiberContextExtractor := fiberContext.NewFiberContextExtractor(application.ModelLoader, application.ApplicationConfig) - // LocalAI API endpoints - galleryService := services.NewGalleryService(application.ApplicationConfig.ModelPath) - galleryService.Start(application.ApplicationConfig.Context, application.BackendConfigLoader) + galleryService := services.NewGalleryService(appConfig.ModelPath) + galleryService.Start(appConfig.Context, cl) app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { @@ -152,17 +172,29 @@ func App(application *core.Application) (*fiber.App, error) { }{Version: internal.PrintableVersion()}) }) + // Make sure directories exists + os.MkdirAll(appConfig.ImageDir, 0755) + os.MkdirAll(appConfig.AudioDir, 0755) + os.MkdirAll(appConfig.UploadDir, 0755) + os.MkdirAll(appConfig.ConfigsDir, 0755) + os.MkdirAll(appConfig.ModelPath, 0755) + + // Load config jsons + utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) + utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) + utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) + app.Get("/swagger/*", swagger.HandlerDefault) // default welcomeRoute( app, - application.BackendConfigLoader, - application.ModelLoader, - application.ApplicationConfig, + cl, + ml, + appConfig, auth, ) - modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(application.ApplicationConfig.Galleries, application.ApplicationConfig.ModelPath, galleryService) + modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) @@ -171,85 +203,83 @@ func App(application *core.Application) (*fiber.App, error) { app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint()) app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint()) - // Stores - storeLoader := model.NewModelLoader("") // TODO: Investigate if this should be migrated to application and reused. Should the path be configurable? Merging for now. 
- app.Post("/stores/set", auth, localai.StoresSetEndpoint(storeLoader, application.ApplicationConfig)) - app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(storeLoader, application.ApplicationConfig)) - app.Post("/stores/get", auth, localai.StoresGetEndpoint(storeLoader, application.ApplicationConfig)) - app.Post("/stores/find", auth, localai.StoresFindEndpoint(storeLoader, application.ApplicationConfig)) - - // openAI compatible API endpoints - - // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(fiberContextExtractor, application.OpenAIService)) - - // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/edits", auth, openai.EditEndpoint(fiberContextExtractor, application.OpenAIService)) - - // assistant - // TODO: Refactor this to the new style eventually - app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants", auth, openai.ListAssistantsEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/assistants", auth, openai.CreateAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Delete("/assistants/:assistant_id/files/:file_id", 
auth, openai.DeleteAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(application.BackendConfigLoader, application.ModelLoader, application.ApplicationConfig)) - - // files - app.Post("/v1/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Post("/files", auth, openai.UploadFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/v1/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/files", auth, openai.ListFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(application.BackendConfigLoader, application.ApplicationConfig)) - - // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(fiberContextExtractor, application.OpenAIService)) - - // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(fiberContextExtractor, application.EmbeddingsBackendService)) - - // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(fiberContextExtractor, application.TranscriptionBackendService)) - app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) - - // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(fiberContextExtractor, application.ImageGenerationBackendService)) + app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig)) // Elevenlabs - app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) + app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig)) - // LocalAI TTS? 
- app.Post("/tts", auth, localai.TTSEndpoint(fiberContextExtractor, application.TextToSpeechBackendService)) + // Stores + sl := model.NewModelLoader("") + app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig)) + app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig)) + app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig)) + app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig)) - if application.ApplicationConfig.ImageDir != "" { - app.Static("/generated-images", application.ApplicationConfig.ImageDir) + // openAI compatible API endpoint + + // chat + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + + // edit + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + + // assistant + app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig)) + app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig)) + + // files + app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) + app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) + app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) + app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) + app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) + app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Get("/files/:file_id/content", auth, 
openai.GetFilesContentsEndpoint(cl, appConfig)) + + // completion + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + + // embeddings + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + + // audio + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig)) + + // images + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig)) + + if appConfig.ImageDir != "" { + app.Static("/generated-images", appConfig.ImageDir) } - if application.ApplicationConfig.AudioDir != "" { - app.Static("/generated-audio", application.ApplicationConfig.AudioDir) + if appConfig.AudioDir != "" { + app.Static("/generated-audio", appConfig.AudioDir) } ok := func(c *fiber.Ctx) error { @@ -261,12 +291,13 @@ func App(application *core.Application) (*fiber.App, error) { app.Get("/readyz", ok) // Experimental Backend Statistics Module - app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(application.BackendMonitorService)) - app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(application.BackendMonitorService)) + backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now + app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor)) + app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor)) // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(application.ListModelsService)) - app.Get("/models", auth, openai.ListModelsEndpoint(application.ListModelsService)) + app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) + app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) diff --git a/core/http/api_test.go b/core/http/api_test.go index bf8feb1c..1553ed21 100644 --- a/core/http/api_test.go +++ b/core/http/api_test.go @@ -12,9 +12,7 @@ import ( "os" "path/filepath" "runtime" - "strings" - "github.com/go-skynet/LocalAI/core" "github.com/go-skynet/LocalAI/core/config" . "github.com/go-skynet/LocalAI/core/http" "github.com/go-skynet/LocalAI/core/schema" @@ -207,7 +205,9 @@ var _ = Describe("API test", func() { var cancel context.CancelFunc var tmpdir string var modelDir string - var application *core.Application + var bcl *config.BackendConfigLoader + var ml *model.ModelLoader + var applicationConfig *config.ApplicationConfig commonOpts := []config.AppOption{ config.WithDebug(true), @@ -252,7 +252,7 @@ var _ = Describe("API test", func() { }, } - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithGalleries(galleries), @@ -261,7 +261,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(backendAssetsDir))...) 
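A sketch mirroring the restored test setup in this hunk, showing how the reverted wiring is consumed: Startup returns the backend config loader, model loader and application config as separate values, and App takes all three instead of a single Application struct. The import alias, listen address and options are placeholders.

package main

import (
	"context"
	"log"

	"github.com/go-skynet/LocalAI/core/config"
	localaihttp "github.com/go-skynet/LocalAI/core/http"
	"github.com/go-skynet/LocalAI/core/startup"
)

func main() {
	ctx := context.Background()

	// Same shape as the test code above: four return values instead of *core.Application.
	cl, ml, appConfig, err := startup.Startup(
		config.WithContext(ctx),
		config.WithModelPath("/models"),
		config.WithDebug(true),
	)
	if err != nil {
		log.Fatal(err)
	}

	app, err := localaihttp.App(cl, ml, appConfig)
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(app.Listen("127.0.0.1:8080"))
}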
Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -474,11 +474,11 @@ var _ = Describe("API test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) - Expect(resp2.Choices[0].Message.ToolCalls[0].Function).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name) + Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) var res map[string]string - err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res) + err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) @@ -487,9 +487,9 @@ var _ = Describe("API test", func() { }) It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() { - // if runtime.GOOS != "linux" { - // Skip("test supported only on linux") - // } + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } modelName := "codellama" response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "github:go-skynet/model-gallery/codellama-7b-instruct.yaml", @@ -504,7 +504,7 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) return response["processed"].(bool) - }, "480s", "10s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) By("testing chat") resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{ @@ -551,13 +551,11 @@ var _ = Describe("API test", func() { }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) - fmt.Printf("\n--- %+v\n\n", resp2.Choices[0].Message) - Expect(resp2.Choices[0].Message.ToolCalls).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.ToolCalls[0]).ToNot(BeNil()) - Expect(resp2.Choices[0].Message.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.ToolCalls[0].Function.Name) + Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) var res map[string]string - err = json.Unmarshal([]byte(resp2.Choices[0].Message.ToolCalls[0].Function.Arguments), &res) + err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) @@ -611,7 +609,7 @@ var _ = Describe("API test", func() { }, } - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithAudioDir(tmpdir), @@ -622,7 +620,7 @@ var _ = Describe("API test", func() { config.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -726,14 +724,14 @@ var _ = 
Describe("API test", func() { var err error - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), config.WithContext(c), config.WithModelPath(modelPath), )...) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -763,11 +761,6 @@ var _ = Describe("API test", func() { Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8? }) It("can generate completions via ggml", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -775,11 +768,6 @@ var _ = Describe("API test", func() { }) It("can generate chat completions via ggml", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -787,11 +775,6 @@ var _ = Describe("API test", func() { }) It("can generate completions from model configs", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -799,11 +782,6 @@ var _ = Describe("API test", func() { }) It("can generate chat completions from model configs", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) @@ -890,9 +868,9 @@ var _ = Describe("API test", func() { Context("backends", func() { It("runs rwkv completion", func() { - // if runtime.GOOS != "linux" { - // Skip("test supported only on linux") - // } + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices) > 0).To(BeTrue()) @@ -913,20 +891,17 @@ var _ = Describe("API test", func() { } Expect(err).ToNot(HaveOccurred()) - - if len(response.Choices) > 0 { - text += response.Choices[0].Text - tokens++ - } + text += response.Choices[0].Text + tokens++ } Expect(text).ToNot(BeEmpty()) Expect(text).To(ContainSubstring("five")) Expect(tokens).ToNot(Or(Equal(1), Equal(0))) }) It("runs rwkv chat completion", func() { - // if runtime.GOOS != "linux" { - // Skip("test 
supported only on linux") - // } + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) Expect(err).ToNot(HaveOccurred()) @@ -1035,14 +1010,14 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - application, err = startup.Startup( + bcl, ml, applicationConfig, err = startup.Startup( append(commonOpts, config.WithContext(c), config.WithModelPath(modelPath), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) - app, err = App(application) + app, err = App(bcl, ml, applicationConfig) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -1066,33 +1041,18 @@ var _ = Describe("API test", func() { } }) It("can generate chat completions from config file (list1)", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate chat completions from config file (list2)", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate edit completions from config file", func() { - bt, ok := os.LookupEnv("BUILD_TYPE") - if ok && strings.ToLower(bt) == "metal" { - Skip("GGML + Metal is known flaky, skip test temporarily") - } - request := openaigo.EditCreateRequestBody{ Model: "list2", Instruction: "foo", diff --git a/core/http/ctx/fiber.go b/core/http/ctx/fiber.go index 99fbcde9..ffb63111 100644 --- a/core/http/ctx/fiber.go +++ b/core/http/ctx/fiber.go @@ -1,88 +1,43 @@ package fiberContext import ( - "context" - "encoding/json" "fmt" "strings" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) -type FiberContextExtractor struct { - ml *model.ModelLoader - appConfig *config.ApplicationConfig -} - -func NewFiberContextExtractor(ml *model.ModelLoader, appConfig *config.ApplicationConfig) *FiberContextExtractor { - return &FiberContextExtractor{ - ml: ml, - appConfig: appConfig, - } -} - // ModelFromContext returns the model from the context // If no model is specified, it will take the first available // Takes a model string as input which should be the one received from the user request. // It returns the model name resolved from the context and an error if any. 
-func (fce *FiberContextExtractor) ModelFromContext(ctx *fiber.Ctx, modelInput string, firstModel bool) (string, error) { - ctxPM := ctx.Params("model") - if ctxPM != "" { - log.Debug().Msgf("[FCE] Overriding param modelInput %q with ctx.Params value %q", modelInput, ctxPM) - modelInput = ctxPM +func ModelFromContext(ctx *fiber.Ctx, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) { + if ctx.Params("model") != "" { + modelInput = ctx.Params("model") } // Set model from bearer token, if available - bearer := strings.TrimPrefix(ctx.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && fce.ml.ExistsInModelPath(bearer) + bearer := strings.TrimLeft(ctx.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) // If no model was specified, take the first available if modelInput == "" && !bearerExists && firstModel { - models, _ := fce.ml.ListModels() + models, _ := loader.ListModels() if len(models) > 0 { modelInput = models[0] - log.Debug().Msgf("[FCE] No model specified, using first available: %s", modelInput) + log.Debug().Msgf("No model specified, using: %s", modelInput) } else { - log.Warn().Msgf("[FCE] No model specified, none available") - return "", fmt.Errorf("[fce] no model specified, none available") + log.Debug().Msgf("No model specified, returning error") + return "", fmt.Errorf("no model specified") } } // If a model is found in bearer token takes precedence if bearerExists { - log.Debug().Msgf("[FCE] Using model from bearer token: %s", bearer) + log.Debug().Msgf("Using model from bearer token: %s", bearer) modelInput = bearer } - - if modelInput == "" { - log.Warn().Msg("[FCE] modelInput is empty") - } return modelInput, nil } - -// TODO: Do we still need the first return value? 
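A hypothetical handler sketch using the restored ModelFromContext helper above: the model name is resolved from the route parameter, the supplied request value, or a bearer token, falling back to the first model found on disk. The route, model path and port are invented for illustration; the helper signature is taken from the restored code.

package main

import (
	"log"

	fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
)

func main() {
	ml := model.NewModelLoader("/models")

	app := fiber.New()
	app.Post("/v1/demo/:model?", func(c *fiber.Ctx) error {
		// firstModel=true: fall back to the first model in the model path if none is given.
		name, err := fiberContext.ModelFromContext(c, ml, "", true)
		if err != nil {
			return c.Status(fiber.StatusBadRequest).SendString(err.Error())
		}
		return c.SendString("resolved model: " + name)
	})
	log.Fatal(app.Listen("127.0.0.1:8080"))
}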
-func (fce *FiberContextExtractor) OpenAIRequestFromContext(c *fiber.Ctx, firstModel bool) (string, *schema.OpenAIRequest, error) { - input := new(schema.OpenAIRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, fmt.Errorf("failed parsing request body: %w", err) - } - - received, _ := json.Marshal(input) - - ctx, cancel := context.WithCancel(fce.appConfig.Context) - input.Context = ctx - input.Cancel = cancel - - log.Debug().Msgf("Request received: %s", string(received)) - - var err error - input.Model, err = fce.ModelFromContext(c, input.Model, firstModel) - - return input.Model, input, err -} diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 4f5db463..841f9b5f 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -2,7 +2,9 @@ package elevenlabs import ( "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" @@ -15,7 +17,7 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/text-to-speech/{voice-id} [post] -func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.ElevenLabsTTSRequest) @@ -26,21 +28,34 @@ func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToS return err } - var err error - input.ModelID, err = fce.ModelFromContext(c, input.ModelID, false) + modelFile, err := fiberContext.ModelFromContext(c, ml, input.ModelID, false) if err != nil { + modelFile = input.ModelID log.Warn().Msgf("Model not found in context: %s", input.ModelID) } - responseChannel := ttsbs.TextToAudioFile(&schema.TTSRequest{ - Model: input.ModelID, - Voice: voiceID, - Input: input.Text, - }) - rawValue := <-responseChannel - if rawValue.Error != nil { - return rawValue.Error + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + if err != nil { + modelFile = input.ModelID + log.Warn().Msgf("Model not found in context: %s", input.ModelID) + } else { + if input.ModelID != "" { + modelFile = input.ModelID + } else { + modelFile = cfg.Model + } } - return c.Download(*rawValue.Value) + log.Debug().Msgf("Request for model: %s", modelFile) + + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg) + if err != nil { + return err + } + return c.Download(filePath) } } diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index dac20388..8c7a664a 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -6,7 +6,7 @@ import ( "github.com/gofiber/fiber/v2" ) -func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { +func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { return 
func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) @@ -23,7 +23,7 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ct } } -func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { +func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index df7841fb..7822e024 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -2,7 +2,9 @@ package localai import ( "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/schema" "github.com/gofiber/fiber/v2" @@ -14,26 +16,45 @@ import ( // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/audio/speech [post] -func TTSEndpoint(fce *fiberContext.FiberContextExtractor, ttsbs *backend.TextToSpeechBackendService) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - var err error + input := new(schema.TTSRequest) // Get input data from the request body - if err = c.BodyParser(input); err != nil { + if err := c.BodyParser(input); err != nil { return err } - input.Model, err = fce.ModelFromContext(c, input.Model, false) + modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false) if err != nil { + modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - responseChannel := ttsbs.TextToAudioFile(input) - rawValue := <-responseChannel - if rawValue.Error != nil { - return rawValue.Error + cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, + config.LoadOptionDebug(appConfig.Debug), + config.LoadOptionThreads(appConfig.Threads), + config.LoadOptionContextSize(appConfig.ContextSize), + config.LoadOptionF16(appConfig.F16), + ) + + if err != nil { + modelFile = input.Model + log.Warn().Msgf("Model not found in context: %s", input.Model) + } else { + modelFile = cfg.Model } - return c.Download(*rawValue.Value) + log.Debug().Msgf("Request for model: %s", modelFile) + + if input.Backend != "" { + cfg.Backend = input.Backend + } + + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg) + if err != nil { + return err + } + return c.Download(filePath) } } diff --git a/core/http/endpoints/openai/assistant.go b/core/http/endpoints/openai/assistant.go index 72cb8b4a..dceb3789 100644 --- a/core/http/endpoints/openai/assistant.go +++ b/core/http/endpoints/openai/assistant.go @@ -339,7 +339,7 @@ func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.Model } } - return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistantID %q", assistantID)) + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find ")) } } diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index a240b024..36d1142b 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -5,11 +5,17 @@ import ( "bytes" "encoding/json" "fmt" + "strings" 
+ "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -19,82 +25,412 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] -func ChatEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { +func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error { + emptyMessage := "" + id := uuid.New().String() + created := int(time.Now().Unix()) + + process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + + responses <- resp + return true + }) + close(responses) + } + processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + result := "" + _, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + result += s + // TODO: Change generated BNF grammar to be compliant with the schema so we can + // stream the result token by token here. + return true + }) + + results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls) + noActionToRun := len(results) > 0 && results[0].name == noAction + + switch { + case noActionToRun: + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt) + if err != nil { + log.Error().Err(err).Msg("error handling question") + return + } + + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + + responses <- resp + + default: + for i, ss := range results { + name, args := ss.name, ss.arguments + + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + responses <- schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Arguments: args, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + } + } + + close(responses) + } + return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + processFunctions := false + funcs := grammar.Functions{} + modelFile, input, err := readRequest(c, ml, startupOptions, true) if err != nil { - return fmt.Errorf("failed reading parameters from request: %w", err) + return fmt.Errorf("failed reading parameters from request:%w", err) } - traceID, finalResultChannel, _, tokenChannel, err := oais.Chat(request, false, request.Stream) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName } - if request.Stream { + if input.ResponseFormat.Type == "json_object" { + input.Grammar = grammar.JSONBNF + } - log.Debug().Msgf("Chat Stream request received") + config.Grammar = input.Grammar + // process functions if we have any defined or if we have a function call string + if len(input.Functions) > 0 && config.ShouldUseFunctions() { + log.Debug().Msgf("Response needs to process functions") + + processFunctions = true + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + funcs = append(funcs, input.Functions...) 
+ if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + } else if input.JSONFunctionGrammarObject != nil { + config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) + } + + // functions are not supported in stream mode (yet?) + toStream := input.Stream + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + // If we are using the tokenizer template, we don't need to process the messages + // unless we are processing functions + if !config.TemplateConfig.UseTokenizerTemplate || processFunctions { + + suppressConfigSystemPrompt := false + mess := []string{} + for messageIndex, i := range input.Messages { + var content string + role := i.Role + + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && i.StringContent != "" + + fcall := i.FunctionCall + if len(i.ToolCalls) > 0 { + fcall = i.ToolCalls + } + + // First attempt to populate content via a chat message specific template + if config.TemplateConfig.ChatMessage != "" { + chatMessageData := model.ChatMessageTemplateData{ + SystemPrompt: config.SystemPrompt, + Role: r, + RoleName: role, + Content: i.StringContent, + FunctionCall: fcall, + FunctionName: i.Name, + LastMessage: messageIndex == (len(input.Messages) - 1), + Function: config.Grammar != "" && (messageIndex == (len(input.Messages) - 1)), + MessageIndex: messageIndex, + } + templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + if err != nil { + log.Error().Err(err).Interface("message", chatMessageData).Str("template", config.TemplateConfig.ChatMessage).Msg("error processing message with template, skipping") + } else { + if templatedChatMessage == "" { + log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + + marshalAnyRole := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + marshalAny := func(f any) { + j, err := json.Marshal(f) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. 
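A minimal, dependency-free sketch of the server-sent-events framing used by the streaming branch restored a bit further below: each chunk is JSON-encoded, written as a "data: ..." line and flushed, and the stream is terminated with "data: [DONE]". The chunk type here is illustrative, not the real schema.OpenAIResponse.

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"os"
)

type chunk struct {
	ID     string `json:"id"`
	Object string `json:"object"`
	Text   string `json:"text"`
}

// writeSSE drains the channel a producer goroutine writes to and frames each
// element as a server-sent event.
func writeSSE(w *bufio.Writer, chunks <-chan chunk) {
	for c := range chunks {
		payload, err := json.Marshal(c)
		if err != nil {
			continue
		}
		fmt.Fprintf(w, "data: %s\n\n", payload)
		w.Flush()
	}
	w.WriteString("data: [DONE]\n\n")
	w.Flush()
}

func main() {
	ch := make(chan chunk, 2)
	ch <- chunk{ID: "1", Object: "chat.completion.chunk", Text: "Hello"}
	ch <- chunk{ID: "1", Object: "chat.completion.chunk", Text: " world"}
	close(ch)
	writeSSE(bufio.NewWriter(os.Stdout), ch)
}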
+ if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, i.StringContent) + } + + if i.FunctionCall != nil { + marshalAnyRole(i.FunctionCall) + } + if i.ToolCalls != nil { + marshalAnyRole(i.ToolCalls) + } + } else { + if contentExists { + content = fmt.Sprint(i.StringContent) + } + if i.FunctionCall != nil { + marshalAny(i.FunctionCall) + } + if i.ToolCalls != nil { + marshalAny(i.ToolCalls) + } + } + // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately + if contentExists && role == "system" { + suppressConfigSystemPrompt = true + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Chat != "" && !processFunctions { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + } + + switch { + case toStream: + + log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // + // c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") + responses := make(chan schema.OpenAIResponse) + + if !processFunctions { + go process(predInput, input, config, ml, responses) + } else { + go processTools(noActionName, predInput, input, config, ml, responses) + } + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { usage := &schema.OpenAIUsage{} toolsCalled := false - for ev := range tokenChannel { - if ev.Error != nil { - log.Debug().Err(ev.Error).Msg("chat streaming responseChannel error") - request.Cancel() - break - } - usage = &ev.Value.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - - if len(ev.Value.Choices[0].Delta.ToolCalls) > 0 { + for ev := range responses { + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices[0].Delta.ToolCalls) > 0 { toolsCalled = true } var buf bytes.Buffer enc := json.NewEncoder(&buf) - if ev.Error != nil { - log.Debug().Err(ev.Error).Msg("[ChatEndpoint] error to debug during tokenChannel handler") - enc.Encode(ev.Error) - } else { - enc.Encode(ev.Value) - } - log.Debug().Msgf("chat streaming sending chunk: %s", buf.String()) + enc.Encode(ev) + log.Debug().Msgf("Sending chunk: %s", buf.String()) _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) if err != nil { - log.Debug().Err(err).Msgf("Sending chunk failed") - 
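// Editorial note (not part of the patch): a minimal client for the SSE stream produced by the
// handler above, which writes one "data: <json>" line per schema.OpenAIResponse chunk and
// terminates with "data: [DONE]". The URL, port and request body are hypothetical examples.
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break
		}
		fmt.Println(payload) // one streamed chat.completion.chunk per line
	}
}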
request.Cancel() - break - } - err = w.Flush() - if err != nil { - log.Debug().Msg("error while flushing, closing connection") - request.Cancel() + log.Debug().Msgf("Sending chunk failed: %v", err) + input.Cancel() break } + w.Flush() } finishReason := "stop" if toolsCalled { finishReason = "tool_calls" - } else if toolsCalled && len(request.Tools) == 0 { + } else if toolsCalled && len(input.Tools) == 0 { finishReason = "function_call" } resp := &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { FinishReason: finishReason, Index: 0, - Delta: &schema.Message{Content: ""}, + Delta: &schema.Message{Content: &emptyMessage}, }}, Object: "chat.completion.chunk", Usage: *usage, @@ -105,21 +441,202 @@ func ChatEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAI w.WriteString("data: [DONE]\n\n") w.Flush() })) - return nil + + // no streaming mode + default: + result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) { + if !processFunctions { + // no function is called, just reply and use stop as finish reason + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + return + } + + results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls) + noActionsToRun := len(results) > 0 && results[0].name == noActionName + + switch { + case noActionsToRun: + result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput) + if err != nil { + log.Error().Err(err).Msg("error handling question") + return + } + *c = append(*c, schema.Choice{ + Message: &schema.Message{Role: "assistant", Content: &result}}) + default: + toolChoice := schema.Choice{ + Message: &schema.Message{ + Role: "assistant", + }, + } + + if len(input.Tools) > 0 { + toolChoice.FinishReason = "tool_calls" + } + + for _, ss := range results { + name, args := ss.name, ss.arguments + if len(input.Tools) > 0 { + // If we are using tools, we condense the function calls into + // a single response choice with all the tools + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + ) + } else { + // otherwise we return more choices directly + *c = append(*c, schema.Choice{ + FinishReason: "function_call", + Message: &schema.Message{ + Role: "assistant", + FunctionCall: map[string]interface{}{ + "name": name, + "arguments": args, + }, + }, + }) + } + } + + if len(input.Tools) > 0 { + // we need to append our result if we are using tools + *c = append(*c, toolChoice) + } + } + + }, nil) + if err != nil { + return err + } + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: result, + Object: "chat.completion", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) } - // TODO is this proper to have exclusive from Stream, or do we need to issue both responses? - rawResponse := <-finalResultChannel - - if rawResponse.Error != nil { - return rawResponse.Error - } - - jsonResult, _ := json.Marshal(rawResponse.Value) - log.Debug().Str("jsonResult", string(jsonResult)).Msg("Chat Final Response") - - // Return the prediction in the response body - return c.JSON(rawResponse.Value) } } + +func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(args), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, prompt, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + return message, nil + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU/GPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) + } + + predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, ml, *config, o, nil) + if err != nil { + log.Error().Err(err).Msg("model inference failed") + return "", err + } + + prediction, err := predFunc() + if err != nil { + log.Error().Err(err).Msg("prediction failed") + return "", err + } + return backend.Finetune(*config, prompt, prediction.Response), nil +} + +type funcCallResults struct { + name string + arguments string +} + +func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { + results := []funcCallResults{} + + // TODO: use generics to avoid this code duplication + if multipleResults { + ss := []map[string]interface{}{} + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + for _, s := range ss { + func_name, ok := s["function"] + if !ok { + continue + } + args, ok := s["arguments"] + if !ok { + continue + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + continue + } + results = append(results, funcCallResults{name: funcName, arguments: string(d)}) + } + } else { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) 
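// Editorial note (not part of the patch): the single-result shape the grammar-constrained model
// output is expected to take, and which parseFunctionCall below maps onto a function name plus a
// stringified arguments object (the form the OpenAI API expects). The sample payload is a
// hypothetical example; only the "function"/"arguments" keys mirror the diff.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	llmResult := `{"function": "get_weather", "arguments": {"city": "Rome"}}`

	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(llmResult), &parsed); err != nil {
		panic(err)
	}

	name, _ := parsed["function"].(string)
	args, _ := json.Marshal(parsed["arguments"]) // re-marshal to a string, as the endpoint does

	fmt.Printf("name=%s arguments=%s\n", name, string(args))
}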
+ ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name, ok := ss["function"] + if !ok { + return results + } + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + if !ok { + return results + } + d, _ := json.Marshal(args) + funcName, ok := func_name.(string) + if !ok { + return results + } + results = append(results, funcCallResults{name: funcName, arguments: string(d)}) + } + + return results +} diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index d8b412a9..69923475 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -4,13 +4,18 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -20,50 +25,116 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] -func CompletionEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + id := uuid.New().String() + created := int(time.Now().Unix()) + + process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []schema.Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + Usage: schema.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + modelFile, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - log.Debug().Msgf("`OpenAIRequest`: %+v", request) + log.Debug().Msgf("`input`: %+v", input) - traceID, finalResultChannel, _, _, tokenChannel, err := oais.Completion(request, false, request.Stream) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) } - if request.Stream { - log.Debug().Msgf("Completion Stream request received") + if input.ResponseFormat.Type == "json_object" { + input.Grammar = grammar.JSONBNF + } + config.Grammar = input.Grammar + + log.Debug().Msgf("Parameter Config: %+v", config) + + if input.Stream { + log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) //c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") + } + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Completion != "" { + templateFile = config.TemplateConfig.Completion + } + + if input.Stream { + if len(config.PromptStrings) > 1 { + return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + } + + predInput := config.PromptStrings[0] + + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + } + + responses := make(chan schema.OpenAIResponse) + + go process(predInput, input, config, ml, responses) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - for ev := range tokenChannel { + + for ev := range responses { var buf bytes.Buffer enc := json.NewEncoder(&buf) - if ev.Error != nil { - log.Debug().Msgf("[CompletionEndpoint] error to debug during tokenChannel handler: %q", ev.Error) - enc.Encode(ev.Error) - } else { - enc.Encode(ev.Value) - } + enc.Encode(ev) - log.Debug().Msgf("completion streaming sending chunk: %s", buf.String()) + log.Debug().Msgf("Sending chunk: %s", buf.String()) fmt.Fprintf(w, "data: %v\n", buf.String()) w.Flush() } resp := &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{ { Index: 0, @@ -80,15 +151,55 @@ func CompletionEndpoint(fce *fiberContext.FiberContextExtractor, oais *services. })) return nil } - // TODO is this proper to have exclusive from Stream, or do we need to issue both responses? - rawResponse := <-finalResultChannel - if rawResponse.Error != nil { - return rawResponse.Error + + var result []schema.Choice + + totalTokenUsage := backend.TokenUsage{} + + for k, i := range config.PromptStrings { + if templateFile != "" { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + } + + r, tokenUsage, err := ComputeChoices( + input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) + }, nil) + if err != nil { + return err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) } - jsonResult, _ := json.Marshal(rawResponse.Value) + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "text_completion", + Usage: schema.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, + } + + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index a33050dd..25497095 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -3,36 +3,92 @@ package openai import ( "encoding/json" "fmt" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" - "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/rs/zerolog/log" ) -func EditEndpoint(fce *fiberContext.FiberContextExtractor, oais *services.OpenAIService) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + modelFile, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - _, finalResultChannel, _, _, _, err := oais.Edit(request, false, request.Stream) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) } - rawResponse := <-finalResultChannel - if rawResponse.Error != nil { - return rawResponse.Error + log.Debug().Msgf("Parameter Config: %+v", config) + + templateFile := "" + + // A model can 
have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model } - jsonResult, _ := json.Marshal(rawResponse.Value) + if config.TemplateConfig.Edit != "" { + templateFile = config.TemplateConfig.Edit + } + + var result []schema.Choice + totalTokenUsage := backend.TokenUsage{} + + for _, i := range config.InputStrings { + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ + Input: i, + Instruction: input.Instruction, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + } + + r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "edit", + Usage: schema.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, + } + + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index be546991..eca34f79 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -3,9 +3,14 @@ package openai import ( "encoding/json" "fmt" + "time" "github.com/go-skynet/LocalAI/core/backend" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/pkg/model" + + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -16,25 +21,63 @@ import ( // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/embeddings [post] -func EmbeddingsEndpoint(fce *fiberContext.FiberContextExtractor, ebs *backend.EmbeddingsBackendService) func(c *fiber.Ctx) error { +func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - _, input, err := fce.OpenAIRequestFromContext(c, true) + model, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - responseChannel := ebs.Embeddings(input) - - rawResponse := <-responseChannel - - if rawResponse.Error != nil { - return rawResponse.Error + config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) } - jsonResult, _ := json.Marshal(rawResponse.Value) + log.Debug().Msgf("Parameter Config: %+v", config) + items := 
[]schema.Item{} + + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + } + + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index ec3d84da..9e806b3e 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -1,18 +1,50 @@ package openai import ( + "bufio" + "encoding/base64" "encoding/json" "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" "github.com/go-skynet/LocalAI/core/backend" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) -// https://platform.openai.com/docs/api-reference/images/create +func downloadFile(url string) (string, error) { + // Get the data + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Create the file + out, err := os.CreateTemp("", "image") + if err != nil { + return "", err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return out.Name(), err +} + +// /* * @@ -27,36 +59,186 @@ import ( * */ - // ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create // @Summary Creates an image given a prompt. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/images/generations [post] -func ImageEndpoint(fce *fiberContext.FiberContextExtractor, igbs *backend.ImageGenerationBackendService) func(c *fiber.Ctx) error { +func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - // TODO: Somewhat a hack. Is there a better place to assign this? 
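// Editorial note (not part of the patch): exercising the /v1/embeddings handler shown above.
// The handler accepts either plain strings (config.InputStrings) or pre-tokenized integer
// arrays (config.InputToken) as "input". URL, port and model name are hypothetical.
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"all-minilm-l6-v2","input":"a sentence to embed"}`)
	resp, err := http.Post("http://localhost:8080/v1/embeddings", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	// Expected shape: {"object":"list","data":[{"object":"embedding","index":0,"embedding":[...]}],...}
	fmt.Println(string(out))
}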
- if igbs.BaseUrlForGeneratedImages == "" { - igbs.BaseUrlForGeneratedImages = c.BaseURL() + "/generated-images/" - } - _, request, err := fce.OpenAIRequestFromContext(c, false) + m, input, err := readRequest(c, ml, appConfig, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - responseChannel := igbs.GenerateImage(request) - rawResponse := <-responseChannel - - if rawResponse.Error != nil { - return rawResponse.Error + if m == "" { + m = model.StableDiffusionBackend } + log.Debug().Msgf("Loading model: %+v", m) - jsonResult, err := json.Marshal(rawResponse.Value) + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false) if err != nil { - return err + return fmt.Errorf("failed reading parameters from request:%w", err) } + + src := "" + if input.File != "" { + + fileData := []byte{} + // check if input.File is an URL, if so download it and save it + // to a temporary file + if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { + out, err := downloadFile(input.File) + if err != nil { + return fmt.Errorf("failed downloading file:%w", err) + } + defer os.RemoveAll(out) + + fileData, err = os.ReadFile(out) + if err != nil { + return fmt.Errorf("failed reading file:%w", err) + } + + } else { + // base 64 decode the file and write it somewhere + // that we will cleanup + fileData, err = base64.StdEncoding.DecodeString(input.File) + if err != nil { + return err + } + } + + // Create a temporary file + outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64") + if err != nil { + return err + } + // write the base64 result + writer := bufio.NewWriter(outputFile) + _, err = writer.Write(fileData) + if err != nil { + outputFile.Close() + return err + } + outputFile.Close() + src = outputFile.Name() + defer os.RemoveAll(src) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + switch config.Backend { + case "stablediffusion": + config.Backend = model.StableDiffusionBackend + case "tinydream": + config.Backend = model.TinyDreamBackend + case "": + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + return fmt.Errorf("invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat.Type == "b64_json" { + b64JSON = true + } + // src and clip_skip + var result []schema.Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := config.Step + if step == 0 { + step = 15 + } + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = appConfig.ImageDir + } + // Create a temporary file + outputFile, err := os.CreateTemp(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, 
negative_prompt, src, output, ml, *config, appConfig) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &schema.Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Data: result, + } + + jsonResult, _ := json.Marshal(resp) log.Debug().Msgf("Response: %s", jsonResult) + // Return the prediction in the response body - return c.JSON(rawResponse.Value) + return c.JSON(resp) } } diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go new file mode 100644 index 00000000..06e784b7 --- /dev/null +++ b/core/http/endpoints/openai/inference.go @@ -0,0 +1,55 @@ +package openai + +import ( + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/config" + + "github.com/go-skynet/LocalAI/core/schema" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ComputeChoices( + req *schema.OpenAIRequest, + predInput string, + config *config.BackendConfig, + o *config.ApplicationConfig, + loader *model.ModelLoader, + cb func(string, *[]schema.Choice), + tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { + n := req.N // number of completions to return + result := []schema.Choice{} + + if n == 0 { + n = 1 + } + + images := []string{} + for _, m := range req.Messages { + images = append(images, m.StringImages...) + } + + // get the model function to call for the result + predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, loader, *config, o, tokenCallback) + if err != nil { + return result, backend.TokenUsage{}, err + } + + tokenUsage := backend.TokenUsage{} + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return result, backend.TokenUsage{}, err + } + + tokenUsage.Prompt += prediction.Usage.Prompt + tokenUsage.Completion += prediction.Usage.Completion + + finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) + cb(finetunedResponse, &result) + + //result = append(result, Choice{Text: prediction}) + + } + return result, tokenUsage, err +} diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 9bb2b2ca..04e611a2 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -1,21 +1,61 @@ package openai import ( + "regexp" + + "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/core/services" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - // If blank, no filter is applied. 
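// Editorial note (not part of the patch): the optional ?filter= query handled just below is
// compiled to a regular expression and applied to both configured model names and loose model
// files. The sample names and pattern here are hypothetical.
package main

import (
	"fmt"
	"regexp"
)

func main() {
	names := []string{"gpt-4", "llama-3-8b-instruct", "whisper-base"}

	filter := "llama" // e.g. GET /v1/models?filter=llama
	rxp, err := regexp.Compile(filter)
	if err != nil {
		panic(err)
	}

	for _, n := range names {
		if rxp.MatchString(n) {
			fmt.Println(n) // llama-3-8b-instruct
		}
	}
}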
+ models, err := ml.ListModels() + if err != nil { + return err + } + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []schema.OpenAIModel{} + + var filterFn func(name string) bool filter := c.Query("filter") + + // If filter is not specified, do not filter the list by model name + if filter == "" { + filterFn = func(_ string) bool { return true } + } else { + // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn + rxp, err := regexp.Compile(filter) + if err != nil { + return err + } + filterFn = func(name string) bool { + return rxp.MatchString(name) + } + } + // By default, exclude any loose files that are already referenced by a configuration file. excludeConfigured := c.QueryBool("excludeConfigured", true) - dataModels, err := lms.ListModels(filter, excludeConfigured) - if err != nil { - return err + // Start with the known configurations + for _, c := range cl.GetAllBackendConfigs() { + if excludeConfigured { + mm[c.Model] = nil + } + + if filterFn(c.Name) { + dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + } + } + + // Then iterate through the loose files: + for _, m := range models { + // And only adds them if they shouldn't be skipped. + if _, exists := mm[m]; !exists && filterFn(m) { + dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) + } } return c.JSON(struct { diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go new file mode 100644 index 00000000..369fb0b8 --- /dev/null +++ b/core/http/endpoints/openai/request.go @@ -0,0 +1,285 @@ +package openai + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/go-skynet/LocalAI/core/config" + fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { + input := new(schema.OpenAIRequest) + + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, fmt.Errorf("failed parsing request body: %w", err) + } + + received, _ := json.Marshal(input) + + ctx, cancel := context.WithCancel(o.Context) + input.Context = ctx + input.Cancel = cancel + + log.Debug().Msgf("Request received: %s", string(received)) + + modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel) + + return modelFile, input, err +} + +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string +func getBase64Image(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := http.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return 
"", fmt.Errorf("not valid string") +} + +func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.ModelBaseName != "" { + config.AutoGPTQ.ModelBaseName = input.ModelBaseName + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.UseFastTokenizer { + config.UseFastTokenizer = input.UseFastTokenizer + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + var toolChoice grammar.Tool + + switch content := input.ToolsChoice.(type) { + case string: + _ = json.Unmarshal([]byte(content), &toolChoice) + case map[string]interface{}: + dat, _ := json.Marshal(content) + _ = json.Unmarshal(dat, &toolChoice) + } + input.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + + // Decode each request's message content + index := 0 + for i, m := range input.Messages { + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + for _, pp := range c { + if pp.Type == "text" { + input.Messages[i].StringContent = pp.Text + } else if pp.Type == "image_url" { + // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: + base64, err := getBase64Image(pp.ImageURL.URL) + if err == nil { + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + // set a placeholder for each image + input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent + index++ + } else { + fmt.Print("Failed encoding image", err) + } + } + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + switch inputs := input.Input.(type) { + case string: + if inputs != 
"" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } +} + +func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { + cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, + config.LoadOptionDebug(debug), + config.LoadOptionThreads(threads), + config.LoadOptionContextSize(ctx), + config.LoadOptionF16(f16), + ) + + // Set the parameters for the language model prediction + updateRequestConfig(cfg, input) + + return cfg, input, err +} diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 572cec12..c7dd39e7 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -9,7 +9,8 @@ import ( "path/filepath" "github.com/go-skynet/LocalAI/core/backend" - fiberContext "github.com/go-skynet/LocalAI/core/http/ctx" + "github.com/go-skynet/LocalAI/core/config" + model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -22,15 +23,17 @@ import ( // @Param file formData file true "file" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] -func TranscriptEndpoint(fce *fiberContext.FiberContextExtractor, tbs *backend.TranscriptionBackendService) func(c *fiber.Ctx) error { +func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - _, request, err := fce.OpenAIRequestFromContext(c, false) + m, input, err := readRequest(c, ml, appConfig, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - // TODO: Investigate this file copy stuff later - potentially belongs in service. 
- + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { @@ -62,16 +65,13 @@ func TranscriptEndpoint(fce *fiberContext.FiberContextExtractor, tbs *backend.Tr log.Debug().Msgf("Audio file copied to: %+v", dst) - request.File = dst - - responseChannel := tbs.Transcribe(request) - rawResponse := <-responseChannel - - if rawResponse.Error != nil { - return rawResponse.Error + tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) + if err != nil { + return err } - log.Debug().Msgf("Transcribed: %+v", rawResponse.Value) + + log.Debug().Msgf("Trascribed: %+v", tr) // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(rawResponse.Value) + return c.Status(http.StatusOK).JSON(tr) } } diff --git a/core/schema/transcription.go b/core/schema/whisper.go similarity index 90% rename from core/schema/transcription.go rename to core/schema/whisper.go index fe1799fa..41413c1f 100644 --- a/core/schema/transcription.go +++ b/core/schema/whisper.go @@ -10,7 +10,7 @@ type Segment struct { Tokens []int `json:"tokens"` } -type TranscriptionResult struct { +type Result struct { Segments []Segment `json:"segments"` Text string `json:"text"` } diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index a610432c..979a67a3 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -15,22 +15,22 @@ import ( gopsutil "github.com/shirou/gopsutil/v3/process" ) -type BackendMonitorService struct { +type BackendMonitor struct { configLoader *config.BackendConfigLoader modelLoader *model.ModelLoader options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. 
} -func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService { - return &BackendMonitorService{ +func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor { + return BackendMonitor{ configLoader: configLoader, modelLoader: modelLoader, options: appConfig, } } -func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) { - config, exists := bms.configLoader.GetBackendConfig(modelName) +func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { + config, exists := bm.configLoader.GetBackendConfig(modelName) var backendId string if exists { backendId = config.Model @@ -46,8 +46,8 @@ func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) return backendId, nil } -func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { - config, exists := bms.configLoader.GetBackendConfig(model) +func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { + config, exists := bm.configLoader.GetBackendConfig(model) var backend string if exists { backend = config.Model @@ -60,7 +60,7 @@ func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*sche backend = fmt.Sprintf("%s.bin", backend) } - pid, err := bms.modelLoader.GetGRPCPID(backend) + pid, err := bm.modelLoader.GetGRPCPID(backend) if err != nil { log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid") @@ -101,12 +101,12 @@ func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*sche }, nil } -func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) { - backendId, err := bms.getModelLoaderIDFromModelName(modelName) +func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) if err != nil { return nil, err } - modelAddr := bms.modelLoader.CheckIsLoaded(backendId) + modelAddr := bm.modelLoader.CheckIsLoaded(backendId) if modelAddr == "" { return nil, fmt.Errorf("backend %s is not currently loaded", backendId) } @@ -114,7 +114,7 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) if rpcErr != nil { log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bms.SampleLocalBackendProcess(backendId) + val, slbErr := bm.SampleLocalBackendProcess(backendId) if slbErr != nil { return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) } @@ -131,10 +131,10 @@ func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.Status return status, nil } -func (bms BackendMonitorService) ShutdownModel(modelName string) error { - backendId, err := bms.getModelLoaderIDFromModelName(modelName) +func (bm BackendMonitor) ShutdownModel(modelName string) error { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) if err != nil { return err } - return bms.modelLoader.ShutdownModel(backendId) + return bm.modelLoader.ShutdownModel(backendId) } diff --git a/core/services/gallery.go 
b/core/services/gallery.go index 1ef8e3e2..b068abbb 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -3,18 +3,14 @@ package services import ( "context" "encoding/json" - "errors" "os" - "path/filepath" "strings" "sync" "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/embedded" - "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) @@ -33,6 +29,18 @@ func NewGalleryService(modelPath string) *GalleryService { } } +func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error { + + config, err := gallery.GetGalleryConfigFromURL(req.URL) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) +} + func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) { g.Lock() defer g.Unlock() @@ -84,10 +92,10 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) } } else if op.ConfigURL != "" { - PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) + startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) err = cl.Preload(g.modelPath) } else { - err = prepareModel(g.modelPath, op.Req, progressCallback) + err = prepareModel(g.modelPath, op.Req, cl, progressCallback) } if err != nil { @@ -119,12 +127,13 @@ type galleryModel struct { ID string `json:"id"` } -func processRequests(modelPath string, galleries []gallery.Gallery, requests []galleryModel) error { +func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { var err error for _, r := range requests { utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { if strings.Contains(r.ID, "@") { err = gallery.InstallModelFromGallery( @@ -149,7 +158,7 @@ func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, g return err } - return processRequests(modelPath, galleries, requests) + return processRequests(modelPath, s, cl, galleries, requests) } func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error { @@ -159,90 +168,5 @@ func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, return err } - return processRequests(modelPath, galleries, requests) -} - -// PreloadModelsConfigurations will preload models from the given list of URLs -// It will download the model if it is not already present in the model path -// It will also try to resolve if the model is an embedded model YAML configuration -func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { - for _, url := range models { - - // As a best effort, try to resolve the model from the remote library - // if it's not resolved we try with the other method below - if modelLibraryURL != "" { - lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) - if err == nil { - if lib[url] != "" { - 
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) - url = lib[url] - } - } - } - - url = embedded.ModelShortURL(url) - switch { - case embedded.ExistsInModelsLibrary(url): - modelYAML, err := embedded.ResolveContent(url) - // If we resolve something, just save it to disk and continue - if err != nil { - log.Error().Err(err).Msg("error resolving model content") - continue - } - - log.Debug().Msgf("[startup] resolved embedded model: %s", url) - md5Name := utils.MD5(url) - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") - } - case downloader.LooksLikeURL(url): - log.Debug().Msgf("[startup] resolved model to download: %s", url) - - // md5 of model name - md5Name := utils.MD5(url) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - }) - if err != nil { - log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") - } - } - default: - if _, err := os.Stat(url); err == nil { - log.Debug().Msgf("[startup] resolved local model: %s", url) - // copy to modelPath - md5Name := utils.MD5(url) - - modelYAML, err := os.ReadFile(url) - if err != nil { - log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") - continue - } - - modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") - } - } else { - log.Warn().Msgf("[startup] failed resolving model '%s'", url) - } - } - } -} - -func prepareModel(modelPath string, req gallery.GalleryModel, downloadStatus func(string, string, string, float64)) error { - - config, err := gallery.GetGalleryConfigFromURL(req.URL) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) 
- - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) + return processRequests(modelPath, s, cl, galleries, requests) } diff --git a/core/services/list_models.go b/core/services/list_models.go deleted file mode 100644 index a21e6faf..00000000 --- a/core/services/list_models.go +++ /dev/null @@ -1,72 +0,0 @@ -package services - -import ( - "regexp" - - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/model" -) - -type ListModelsService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig -} - -func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService { - return &ListModelsService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - } -} - -func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) { - - models, err := lms.ml.ListModels() - if err != nil { - return nil, err - } - - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []schema.OpenAIModel{} - - var filterFn func(name string) bool - - // If filter is not specified, do not filter the list by model name - if filter == "" { - filterFn = func(_ string) bool { return true } - } else { - // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn - rxp, err := regexp.Compile(filter) - if err != nil { - return nil, err - } - filterFn = func(name string) bool { - return rxp.MatchString(name) - } - } - - // Start with the known configurations - for _, c := range lms.bcl.GetAllBackendConfigs() { - if excludeConfigured { - mm[c.Model] = nil - } - - if filterFn(c.Name) { - dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) - } - } - - // Then iterate through the loose files: - for _, m := range models { - // And only adds them if they shouldn't be skipped. - if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) - } - } - - return dataModels, nil -} diff --git a/core/services/openai.go b/core/services/openai.go deleted file mode 100644 index 7a2679ad..00000000 --- a/core/services/openai.go +++ /dev/null @@ -1,808 +0,0 @@ -package services - -import ( - "encoding/json" - "errors" - "fmt" - "strings" - "sync" - "time" - - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/pkg/concurrency" - "github.com/go-skynet/LocalAI/pkg/grammar" - "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/google/uuid" - "github.com/imdario/mergo" - "github.com/rs/zerolog/log" -) - -type endpointGenerationConfigurationFn func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration - -type endpointConfiguration struct { - SchemaObject string - TemplatePath string - TemplateData model.PromptTemplateData - ResultMappingFn func(resp *backend.LLMResponse, index int) schema.Choice - CompletionMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] - TokenMappingFn func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] -} - -// TODO: This is used for completion and edit. I am pretty sure I forgot parts, but fix it later. 
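// Editorial note (not part of the patch): the deleted service below passed results over channels
// wrapped in concurrency.ErrorOr[T]; this revert replaces that with plain channels of
// schema.OpenAIResponse. Since pkg/concurrency is removed by this patch, the sketch defines a
// local stand-in to show the pattern being retired.
package main

import (
	"errors"
	"fmt"
)

type errorOr[T any] struct { // stand-in for the removed concurrency.ErrorOr[T]
	Value T
	Error error
}

func main() {
	results := make(chan errorOr[string], 2)
	results <- errorOr[string]{Value: "a completion"}
	results <- errorOr[string]{Error: errors.New("backend failed")}
	close(results)

	for r := range results {
		if r.Error != nil {
			fmt.Println("error:", r.Error)
			continue
		}
		fmt.Println("value:", r.Value)
	}
}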
-func simpleMapper(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { - if resp.Error != nil || resp.Value == nil { - return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error} - } - return concurrency.ErrorOr[*schema.OpenAIResponse]{ - Value: &schema.OpenAIResponse{ - Choices: []schema.Choice{ - { - Text: resp.Value.Response, - }, - }, - Usage: schema.OpenAIUsage{ - PromptTokens: resp.Value.Usage.Prompt, - CompletionTokens: resp.Value.Usage.Completion, - TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion, - }, - }, - } -} - -// TODO: Consider alternative names for this. -// The purpose of this struct is to hold a reference to the OpenAI request context information -// This keeps things simple within core/services/openai.go and allows consumers to "see" this information if they need it -type OpenAIRequestTraceID struct { - ID string - Created int -} - -// This type split out from core/backend/llm.go - I'm still not _totally_ sure about this, but it seems to make sense to keep the generic LLM code from the OpenAI specific higher level functionality -type OpenAIService struct { - bcl *config.BackendConfigLoader - ml *model.ModelLoader - appConfig *config.ApplicationConfig - llmbs *backend.LLMBackendService -} - -func NewOpenAIService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, llmbs *backend.LLMBackendService) *OpenAIService { - return &OpenAIService{ - bcl: bcl, - ml: ml, - appConfig: appConfig, - llmbs: llmbs, - } -} - -// Keeping in place as a reminder to POTENTIALLY ADD MORE VALIDATION HERE??? -func (oais *OpenAIService) getConfig(request *schema.OpenAIRequest) (*config.BackendConfig, *schema.OpenAIRequest, error) { - return oais.bcl.LoadBackendConfigForModelAndOpenAIRequest(request.Model, request, oais.appConfig) -} - -// TODO: It would be a lot less messy to make a return struct that had references to each of these channels -// INTENTIONALLY not doing that quite yet - I believe we need to let the references to unused channels die for the GC to automatically collect -- can we manually free()? -// finalResultsChannel is the primary async return path: one result for the entire request. -// promptResultsChannels is DUBIOUS. It's expected to be raw fan-out used within the function itself, but I am exposing for testing? One bundle of LLMResponseBundle per PromptString? Gets all N completions for a single prompt. -// completionsChannel is a channel that emits one *LLMResponse per generated completion, be that different prompts or N. Seems the most useful other than "entire request" Request is available to attempt tracing??? -// tokensChannel is a channel that emits one *LLMResponse per generated token. Let's see what happens! 
-func (oais *OpenAIService) Completion(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration { - return endpointConfiguration{ - SchemaObject: "text_completion", - TemplatePath: bc.TemplateConfig.Completion, - TemplateData: model.PromptTemplateData{ - SystemPrompt: bc.SystemPrompt, - }, - ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice { - return schema.Choice{ - Index: promptIndex, - FinishReason: "stop", - Text: resp.Response, - } - }, - CompletionMappingFn: simpleMapper, - TokenMappingFn: simpleMapper, - } - }, notifyOnPromptResult, notifyOnToken, nil) -} - -func (oais *OpenAIService) Edit(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - return oais.GenerateTextFromRequest(request, func(bc *config.BackendConfig, request *schema.OpenAIRequest) endpointConfiguration { - - return endpointConfiguration{ - SchemaObject: "edit", - TemplatePath: bc.TemplateConfig.Edit, - TemplateData: model.PromptTemplateData{ - SystemPrompt: bc.SystemPrompt, - Instruction: request.Instruction, - }, - ResultMappingFn: func(resp *backend.LLMResponse, promptIndex int) schema.Choice { - return schema.Choice{ - Index: promptIndex, - FinishReason: "stop", - Text: resp.Response, - } - }, - CompletionMappingFn: simpleMapper, - TokenMappingFn: simpleMapper, - } - }, notifyOnPromptResult, notifyOnToken, nil) -} - -func (oais *OpenAIService) Chat(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - return oais.GenerateFromMultipleMessagesChatRequest(request, notifyOnPromptResult, notifyOnToken, nil) -} - -func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest, endpointConfigFn endpointGenerationConfigurationFn, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], promptResultsChannels []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - if initialTraceID == nil { - traceID = &OpenAIRequestTraceID{ - ID: uuid.New().String(), - Created: int(time.Now().Unix()), - } - } else { - traceID = initialTraceID - } - - bc, request, err := oais.getConfig(request) - if err != nil { - log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting 
configuration") - return - } - - if request.ResponseFormat.Type == "json_object" { - request.Grammar = grammar.JSONBNF - } - - bc.Grammar = request.Grammar - - if request.Stream && len(bc.PromptStrings) > 1 { - log.Warn().Msg("potentially cannot handle more than 1 `PromptStrings` when Streaming?") - } - - rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - finalResultChannel = rawFinalResultChannel - promptResultsChannels = []<-chan concurrency.ErrorOr[*backend.LLMResponseBundle]{} - var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - if notifyOnPromptResult { - rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - if notifyOnToken { - rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - - promptResultsChannelLock := sync.Mutex{} - - endpointConfig := endpointConfigFn(bc, request) - - if len(endpointConfig.TemplatePath) == 0 { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) { - endpointConfig.TemplatePath = bc.Model - } else { - log.Warn().Msgf("failed to find any template for %+v", request) - } - } - - setupWG := sync.WaitGroup{} - var prompts []string - if lPS := len(bc.PromptStrings); lPS > 0 { - setupWG.Add(lPS) - prompts = bc.PromptStrings - } else { - setupWG.Add(len(bc.InputStrings)) - prompts = bc.InputStrings - } - - var setupError error = nil - - for pI, p := range prompts { - - go func(promptIndex int, prompt string) { - if endpointConfig.TemplatePath != "" { - promptTemplateData := model.PromptTemplateData{ - Input: prompt, - } - err := mergo.Merge(promptTemplateData, endpointConfig.TemplateData, mergo.WithOverride) - if err == nil { - templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, endpointConfig.TemplatePath, promptTemplateData) - if err == nil { - prompt = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", prompt) - } - } - } - - log.Debug().Msgf("[OAIS GenerateTextFromRequest] Prompt: %q", prompt) - promptResultsChannel, completionChannels, tokenChannels, err := oais.llmbs.GenerateText(prompt, request, bc, - func(r *backend.LLMResponse) schema.Choice { - return endpointConfig.ResultMappingFn(r, promptIndex) - }, notifyOnPromptResult, notifyOnToken) - if err != nil { - log.Error().Msgf("Unable to generate text prompt: %q\nerr: %q", prompt, err) - promptResultsChannelLock.Lock() - setupError = errors.Join(setupError, err) - promptResultsChannelLock.Unlock() - setupWG.Done() - return - } - if notifyOnPromptResult { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(completionChannels, endpointConfig.CompletionMappingFn), rawCompletionsChannel, true) - } - if notifyOnToken { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, endpointConfig.TokenMappingFn), rawTokenChannel, true) - } - promptResultsChannelLock.Lock() - promptResultsChannels = append(promptResultsChannels, promptResultsChannel) - promptResultsChannelLock.Unlock() - setupWG.Done() - }(pI, p) - - } - setupWG.Wait() - - // If any of the setup goroutines experienced an error, quit early here. 
- if setupError != nil { - go func() { - log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup") - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError} - close(rawFinalResultChannel) - }() - return - } - - initialResponse := &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, - Object: endpointConfig.SchemaObject, - Usage: schema.OpenAIUsage{}, - } - - // utils.SliceOfChannelsRawMerger[[]schema.Choice](promptResultsChannels, rawFinalResultChannel, func(results []schema.Choice) (*schema.OpenAIResponse, error) { - concurrency.SliceOfChannelsReducer( - promptResultsChannels, rawFinalResultChannel, - func(iv concurrency.ErrorOr[*backend.LLMResponseBundle], result concurrency.ErrorOr[*schema.OpenAIResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { - - if iv.Error != nil { - result.Error = iv.Error - return result - } - result.Value.Usage.PromptTokens += iv.Value.Usage.Prompt - result.Value.Usage.CompletionTokens += iv.Value.Usage.Completion - result.Value.Usage.TotalTokens = result.Value.Usage.PromptTokens + result.Value.Usage.CompletionTokens - - result.Value.Choices = append(result.Value.Choices, iv.Value.Response...) - - return result - }, concurrency.ErrorOr[*schema.OpenAIResponse]{Value: initialResponse}, true) - - completionsChannel = rawCompletionsChannel - tokenChannel = rawTokenChannel - - return -} - -// TODO: For porting sanity, this is distinct from GenerateTextFromRequest and is _currently_ specific to Chat purposes -// this is not a final decision -- just a reality of moving a lot of parts at once -// / This has _become_ Chat which wasn't the goal... More cleanup in the future once it's stable? -func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *schema.OpenAIRequest, notifyOnPromptResult bool, notifyOnToken bool, initialTraceID *OpenAIRequestTraceID) ( - traceID *OpenAIRequestTraceID, finalResultChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], - completionsChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], tokenChannel <-chan concurrency.ErrorOr[*schema.OpenAIResponse], err error) { - - if initialTraceID == nil { - traceID = &OpenAIRequestTraceID{ - ID: uuid.New().String(), - Created: int(time.Now().Unix()), - } - } else { - traceID = initialTraceID - } - - bc, request, err := oais.getConfig(request) - if err != nil { - return - } - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if bc.FunctionsConfig.NoActionFunctionName != "" { - noActionName = bc.FunctionsConfig.NoActionFunctionName - } - if bc.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = bc.FunctionsConfig.NoActionDescriptionName - } - - if request.ResponseFormat.Type == "json_object" { - request.Grammar = grammar.JSONBNF - } - - bc.Grammar = request.Grammar - - processFunctions := false - funcs := grammar.Functions{} - // process functions if we have any defined or if we have a function call string - if len(request.Functions) > 0 && bc.ShouldUseFunctions() { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - 
"description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, request.Functions...) - if !bc.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if bc.FunctionToCall() != "" { - funcs = funcs.Select(bc.FunctionToCall()) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - bc.Grammar = jsStruct.Grammar("", bc.FunctionsConfig.ParallelCalls) - } else if request.JSONFunctionGrammarObject != nil { - bc.Grammar = request.JSONFunctionGrammarObject.Grammar("", bc.FunctionsConfig.ParallelCalls) - } - - if request.Stream && processFunctions { - log.Warn().Msg("Streaming + Functions is highly experimental in this version") - } - - var predInput string - - if !bc.TemplateConfig.UseTokenizerTemplate || processFunctions { - - suppressConfigSystemPrompt := false - mess := []string{} - for messageIndex, i := range request.Messages { - var content string - role := i.Role - - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if (i.FunctionCall != nil || i.ToolCalls != nil) && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := bc.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := bc.Roles[role] - contentExists := i.Content != nil && i.StringContent != "" - - fcall := i.FunctionCall - if len(i.ToolCalls) > 0 { - fcall = i.ToolCalls - } - - // First attempt to populate content via a chat message specific template - if bc.TemplateConfig.ChatMessage != "" { - chatMessageData := model.ChatMessageTemplateData{ - SystemPrompt: bc.SystemPrompt, - Role: r, - RoleName: role, - Content: i.StringContent, - FunctionCall: fcall, - FunctionName: i.Name, - LastMessage: messageIndex == (len(request.Messages) - 1), - Function: bc.Grammar != "" && (messageIndex == (len(request.Messages) - 1)), - MessageIndex: messageIndex, - } - templatedChatMessage, err := oais.ml.EvaluateTemplateForChatMessage(bc.TemplateConfig.ChatMessage, chatMessageData) - if err != nil { - log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, bc.TemplateConfig.ChatMessage, err) - } else { - if templatedChatMessage == "" { - log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", bc.TemplateConfig.ChatMessage, chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf - } - log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) - content = templatedChatMessage - } - } - marshalAnyRole := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - marshalAny := func(f any) { - j, err := json.Marshal(f) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. 
- if content == "" { - if r != "" { - if contentExists { - content = fmt.Sprint(r, i.StringContent) - } - - if i.FunctionCall != nil { - marshalAnyRole(i.FunctionCall) - } - } else { - if contentExists { - content = fmt.Sprint(i.StringContent) - } - - if i.FunctionCall != nil { - marshalAny(i.FunctionCall) - } - - if i.ToolCalls != nil { - marshalAny(i.ToolCalls) - } - } - // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately - if contentExists && role == "system" { - suppressConfigSystemPrompt = true - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - templateFile := "" - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if oais.ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", bc.Model)) { - templateFile = bc.Model - } - - if bc.TemplateConfig.Chat != "" && !processFunctions { - templateFile = bc.TemplateConfig.Chat - } - - if bc.TemplateConfig.Functions != "" && processFunctions { - templateFile = bc.TemplateConfig.Functions - } - - if templateFile != "" { - templatedInput, err := oais.ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: bc.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - } - } - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", bc.Grammar) - } - - rawFinalResultChannel := make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - var rawCompletionsChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - var rawTokenChannel chan concurrency.ErrorOr[*schema.OpenAIResponse] - if notifyOnPromptResult { - rawCompletionsChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - if notifyOnToken { - rawTokenChannel = make(chan concurrency.ErrorOr[*schema.OpenAIResponse]) - } - - rawResultChannel, individualCompletionChannels, tokenChannels, err := oais.llmbs.GenerateText(predInput, request, bc, func(resp *backend.LLMResponse) schema.Choice { - return schema.Choice{ - Index: 0, // ??? - FinishReason: "stop", - Message: &schema.Message{ - Role: "assistant", - Content: resp.Response, - }, - } - }, notifyOnPromptResult, notifyOnToken) - - chatSimpleMappingFn := func(resp concurrency.ErrorOr[*backend.LLMResponse]) concurrency.ErrorOr[*schema.OpenAIResponse] { - if resp.Error != nil || resp.Value == nil { - return concurrency.ErrorOr[*schema.OpenAIResponse]{Error: resp.Error} - } - return concurrency.ErrorOr[*schema.OpenAIResponse]{ - Value: &schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: []schema.Choice{ - { - Delta: &schema.Message{ - Role: "assistant", - Content: resp.Value.Response, - }, - Index: 0, - }, - }, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: resp.Value.Usage.Prompt, - CompletionTokens: resp.Value.Usage.Completion, - TotalTokens: resp.Value.Usage.Prompt + resp.Value.Usage.Completion, - }, - }, - } - } - - if notifyOnPromptResult { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(individualCompletionChannels, chatSimpleMappingFn), rawCompletionsChannel, true) - } - if notifyOnToken { - concurrency.SliceOfChannelsRawMergerWithoutMapping(concurrency.SliceOfChannelsTransformer(tokenChannels, chatSimpleMappingFn), rawTokenChannel, true) - } - - go func() { - rawResult := <-rawResultChannel - if rawResult.Error != nil { - log.Warn().Msgf("OpenAIService::processTools GenerateText error [DEBUG THIS?] %q", rawResult.Error) - return - } - llmResponseChoices := rawResult.Value.Response - - if processFunctions && len(llmResponseChoices) > 1 { - log.Warn().Msgf("chat functions response with %d choices in response, debug this?", len(llmResponseChoices)) - log.Debug().Msgf("%+v", llmResponseChoices) - } - - for _, result := range rawResult.Value.Response { - // If no functions, just return the raw result. - if !processFunctions { - - resp := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{result}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: rawResult.Value.Usage.Prompt, - CompletionTokens: rawResult.Value.Usage.Completion, - TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion, - }, - } - - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp} - - continue - } - // At this point, things are function specific! - - // Oh no this can't be the right way to do this... but it works. Save us, mudler! - fString := fmt.Sprintf("%s", result.Message.Content) - results := parseFunctionCall(fString, bc.FunctionsConfig.ParallelCalls) - noActionToRun := (len(results) > 0 && results[0].name == noActionName) - - if noActionToRun { - log.Debug().Msg("-- noActionToRun branch --") - initialMessage := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: ""}}}, - Object: "stop", - } - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage} - - result, err := oais.handleQuestion(bc, request, results[0].arguments, predInput) - if err != nil { - log.Error().Msgf("error handling question: %s", err.Error()) - return - } - - resp := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: rawResult.Value.Usage.Prompt, - CompletionTokens: rawResult.Value.Usage.Completion, - TotalTokens: rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion, - }, - } - - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &resp} - - } else { - log.Debug().Msgf("[GenerateFromMultipleMessagesChatRequest] fnResultsBranch: %+v", results) - for i, ss := range results { - name, args := ss.name, ss.arguments - - initialMessage := schema.OpenAIResponse{ - ID: traceID.ID, - Created: traceID.Created, - Model: request.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{ - FinishReason: "function_call", - Message: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: i, - ID: traceID.ID, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, - }, - }, - }, - }}}, - Object: "chat.completion.chunk", - } - rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Value: &initialMessage} - } - } - } - - close(rawFinalResultChannel) - }() - - finalResultChannel = rawFinalResultChannel - completionsChannel = rawCompletionsChannel - tokenChannel = rawTokenChannel - return -} - -func (oais *OpenAIService) handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, args, prompt string) (string, error) { - log.Debug().Msgf("[handleQuestion called] nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(args), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = oais.llmbs.Finetune(*config, prompt, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - return message, nil - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU/GPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) 
- } - - resultChannel, _, err := oais.llmbs.Inference(input.Context, &backend.LLMRequest{ - Text: prompt, - Images: images, - RawMessages: input.Messages, // Experimental - }, config, false) - - if err != nil { - log.Error().Msgf("inference setup error: %s", err.Error()) - return "", err - } - - raw := <-resultChannel - if raw.Error != nil { - log.Error().Msgf("inference error: %q", raw.Error.Error()) - return "", err - } - if raw.Value == nil { - log.Warn().Msgf("nil inference response") - return "", nil - } - return oais.llmbs.Finetune(*config, prompt, raw.Value.Response), nil -} - -type funcCallResults struct { - name string - arguments string -} - -func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { - - results := []funcCallResults{} - - // TODO: use generics to avoid this code duplication - if multipleResults { - ss := []map[string]interface{}{} - s := utils.EscapeNewLines(llmresult) - json.Unmarshal([]byte(s), &ss) - - for _, s := range ss { - func_name, ok := s["function"] - if !ok { - continue - } - args, ok := s["arguments"] - if !ok { - continue - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - continue - } - results = append(results, funcCallResults{name: funcName, arguments: string(d)}) - } - } else { - // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s := utils.EscapeNewLines(llmresult) - if err := json.Unmarshal([]byte(s), &ss); err != nil { - log.Error().Msgf("error unmarshalling JSON: %s", err.Error()) - return results - } - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name, ok := ss["function"] - if !ok { - log.Debug().Msgf("ss[function] is not OK!, llm result: %q", llmresult) - return results - } - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - if !ok { - log.Debug().Msg("ss[arguments] is not OK!") - return results - } - d, _ := json.Marshal(args) - funcName, ok := func_name.(string) - if !ok { - log.Debug().Msgf("unexpected func_name: %+v", func_name) - return results - } - results = append(results, funcCallResults{name: funcName, arguments: string(d)}) - } - return results -} diff --git a/core/startup/startup.go b/core/startup/startup.go index 92ccaa9d..6298f034 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -4,21 +4,17 @@ import ( "fmt" "os" - "github.com/go-skynet/LocalAI/core" - "github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/config" - openaiendpoint "github.com/go-skynet/LocalAI/core/http/endpoints/openai" // TODO: This is dubious. Fix this when splitting assistant api up. 
"github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" + pkgStartup "github.com/go-skynet/LocalAI/pkg/startup" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -// (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { -func Startup(opts ...config.AppOption) (*core.Application, error) { +func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) { options := config.NewApplicationConfig(opts...) zerolog.SetGlobalLevel(zerolog.InfoLevel) @@ -31,75 +27,68 @@ func Startup(opts ...config.AppOption) (*core.Application, error) { // Make sure directories exists if options.ModelPath == "" { - return nil, fmt.Errorf("options.ModelPath cannot be empty") + return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty") } err := os.MkdirAll(options.ModelPath, 0755) if err != nil { - return nil, fmt.Errorf("unable to create ModelPath: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err) } if options.ImageDir != "" { err := os.MkdirAll(options.ImageDir, 0755) if err != nil { - return nil, fmt.Errorf("unable to create ImageDir: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err) } } if options.AudioDir != "" { err := os.MkdirAll(options.AudioDir, 0755) if err != nil { - return nil, fmt.Errorf("unable to create AudioDir: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err) } } if options.UploadDir != "" { err := os.MkdirAll(options.UploadDir, 0755) if err != nil { - return nil, fmt.Errorf("unable to create UploadDir: %q", err) - } - } - if options.ConfigsDir != "" { - err := os.MkdirAll(options.ConfigsDir, 0755) - if err != nil { - return nil, fmt.Errorf("unable to create ConfigsDir: %q", err) + return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err) } } - // Load config jsons - utils.LoadConfig(options.UploadDir, openaiendpoint.UploadedFilesFile, &openaiendpoint.UploadedFiles) - utils.LoadConfig(options.ConfigsDir, openaiendpoint.AssistantsConfigFile, &openaiendpoint.Assistants) - utils.LoadConfig(options.ConfigsDir, openaiendpoint.AssistantsFileConfigFile, &openaiendpoint.AssistantFiles) + // + pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) - app := createApplication(options) + cl := config.NewBackendConfigLoader() + ml := model.NewModelLoader(options.ModelPath) - services.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) 
+ configLoaderOpts := options.ToConfigLoaderOptions() - if err := app.BackendConfigLoader.LoadBackendConfigsFromPath(options.ModelPath, app.ApplicationConfig.ToConfigLoaderOptions()...); err != nil { + if err := cl.LoadBackendConfigsFromPath(options.ModelPath, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config files") } if options.ConfigFile != "" { - if err := app.BackendConfigLoader.LoadBackendConfigFile(options.ConfigFile, app.ApplicationConfig.ToConfigLoaderOptions()...); err != nil { + if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config file") } } - if err := app.BackendConfigLoader.Preload(options.ModelPath); err != nil { + if err := cl.Preload(options.ModelPath); err != nil { log.Error().Err(err).Msg("error downloading models") } if options.PreloadJSONModels != "" { - if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, app.BackendConfigLoader, options.Galleries); err != nil { - return nil, err + if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { + return nil, nil, nil, err } } if options.PreloadModelsFromPath != "" { - if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, app.BackendConfigLoader, options.Galleries); err != nil { - return nil, err + if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { + return nil, nil, nil, err } } if options.Debug { - for _, v := range app.BackendConfigLoader.ListBackendConfigs() { - cfg, _ := app.BackendConfigLoader.GetBackendConfig(v) + for _, v := range cl.ListBackendConfigs() { + cfg, _ := cl.GetBackendConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } @@ -117,17 +106,17 @@ func Startup(opts ...config.AppOption) (*core.Application, error) { go func() { <-options.Context.Done() log.Debug().Msgf("Context canceled, shutting down") - app.ModelLoader.StopAllGRPC() + ml.StopAllGRPC() }() if options.WatchDog { wd := model.NewWatchDog( - app.ModelLoader, + ml, options.WatchDogBusyTimeout, options.WatchDogIdleTimeout, options.WatchDogBusy, options.WatchDogIdle) - app.ModelLoader.SetWatchDog(wd) + ml.SetWatchDog(wd) go wd.Run() go func() { <-options.Context.Done() @@ -137,35 +126,5 @@ func Startup(opts ...config.AppOption) (*core.Application, error) { } log.Info().Msg("core/startup process completed!") - return app, nil -} - -// In Lieu of a proper DI framework, this function wires up the Application manually. -// This is in core/startup rather than core/state.go to keep package references clean! 
-func createApplication(appConfig *config.ApplicationConfig) *core.Application { - app := &core.Application{ - ApplicationConfig: appConfig, - BackendConfigLoader: config.NewBackendConfigLoader(), - ModelLoader: model.NewModelLoader(appConfig.ModelPath), - } - - var err error - - app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - - app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath) - app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) - app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) - - app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() - if err != nil { - log.Warn().Msg("Unable to initialize LocalAIMetricsService - non-fatal, optional service") - } - - return app + return cl, ml, options, nil } diff --git a/core/state.go b/core/state.go deleted file mode 100644 index cf0d614b..00000000 --- a/core/state.go +++ /dev/null @@ -1,41 +0,0 @@ -package core - -import ( - "github.com/go-skynet/LocalAI/core/backend" - "github.com/go-skynet/LocalAI/core/config" - "github.com/go-skynet/LocalAI/core/services" - "github.com/go-skynet/LocalAI/pkg/model" -) - -// TODO: Can I come up with a better name or location for this? -// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy -// Perhaps a proper DI system is worth it in the future, but for now keep things simple. -type Application struct { - - // Application-Level Config - ApplicationConfig *config.ApplicationConfig - // ApplicationState *ApplicationState - - // Core Low-Level Services - BackendConfigLoader *config.BackendConfigLoader - ModelLoader *model.ModelLoader - - // Backend Services - EmbeddingsBackendService *backend.EmbeddingsBackendService - ImageGenerationBackendService *backend.ImageGenerationBackendService - LLMBackendService *backend.LLMBackendService - TranscriptionBackendService *backend.TranscriptionBackendService - TextToSpeechBackendService *backend.TextToSpeechBackendService - - // LocalAI System Services - BackendMonitorService *services.BackendMonitorService - GalleryService *services.GalleryService - ListModelsService *services.ListModelsService - LocalAIMetricsService *services.LocalAIMetricsService - OpenAIService *services.OpenAIService -} - -// TODO [NEXT PR?]: Break up ApplicationConfig. 
-// Migrate over stuff that is not set via config at all - especially runtime stuff -type ApplicationState struct { -} diff --git a/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru b/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru deleted file mode 100644 index c33bafe1..00000000 --- a/examples/bruno/LocalAI Test Requests/llm text/-completions Stream.bru +++ /dev/null @@ -1,25 +0,0 @@ -meta { - name: -completions Stream - type: http - seq: 4 -} - -post { - url: {{PROTOCOL}}{{HOST}}:{{PORT}}/completions - body: json - auth: none -} - -headers { - Content-Type: application/json -} - -body:json { - { - "model": "{{DEFAULT_MODEL}}", - "prompt": "function downloadFile(string url, string outputPath) {", - "max_tokens": 256, - "temperature": 0.5, - "stream": true - } -} diff --git a/pkg/concurrency/concurrency.go b/pkg/concurrency/concurrency.go deleted file mode 100644 index 324e8cc5..00000000 --- a/pkg/concurrency/concurrency.go +++ /dev/null @@ -1,135 +0,0 @@ -package concurrency - -import ( - "sync" -) - -// TODO: closeWhenDone bool parameter :: -// It currently is experimental, and therefore exists. -// Is there ever a situation to use false? - -// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of a second type. -// mappingFn allows the caller to convert from the input type to the output type -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsRawMerger[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan IndividualResultType, outputChannel chan<- OutputResultType, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup { - var wg sync.WaitGroup - wg.Add(len(individualResultChannels)) - mergingFn := func(c <-chan IndividualResultType) { - for r := range c { - mr, err := mappingFn(r) - if err == nil { - outputChannel <- mr - } - } - wg.Done() - } - for _, irc := range individualResultChannels { - go mergingFn(irc) - } - if closeWhenDone { - go func() { - wg.Wait() - close(outputChannel) - }() - } - - return &wg -} - -// This function is used to merge the results of a slice of channels of a specific result type down to a single result channel of THE SAME TYPE. -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsRawMergerWithoutMapping[ResultType any](individualResultsChannels []<-chan ResultType, outputChannel chan<- ResultType, closeWhenDone bool) *sync.WaitGroup { - return SliceOfChannelsRawMerger(individualResultsChannels, outputChannel, func(v ResultType) (ResultType, error) { return v, nil }, closeWhenDone) -} - -// This function is used to merge the results of a slice of channels of a specific result type down to a single succcess result channel of a second type, and an error channel -// mappingFn allows the caller to convert from the input type to the output type -// This variant is designed to be aware of concurrency.ErrorOr[T], splitting successes from failures. 
-// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsMergerWithErrors[IndividualResultType any, OutputResultType any](individualResultChannels []<-chan ErrorOr[IndividualResultType], successChannel chan<- OutputResultType, errorChannel chan<- error, mappingFn func(IndividualResultType) (OutputResultType, error), closeWhenDone bool) *sync.WaitGroup { - var wg sync.WaitGroup - wg.Add(len(individualResultChannels)) - mergingFn := func(c <-chan ErrorOr[IndividualResultType]) { - for r := range c { - if r.Error != nil { - errorChannel <- r.Error - } else { - mv, err := mappingFn(r.Value) - if err != nil { - errorChannel <- err - } else { - successChannel <- mv - } - } - } - wg.Done() - } - for _, irc := range individualResultChannels { - go mergingFn(irc) - } - if closeWhenDone { - go func() { - wg.Wait() - close(successChannel) - close(errorChannel) - }() - } - return &wg -} - -// This function is used to reduce down the results of a slice of channels of a specific result type down to a single result value of a second type. -// reducerFn allows the caller to convert from the input type to the output type -// if closeWhenDone is set to true, the output channel will be closed when all individual result channels of the slice have been closed - otherwise it will be left open for future use. -// The same WaitGroup used to trigger that optional closing is returned for any other synchronization purposes. -func SliceOfChannelsReducer[InputResultType any, OutputResultType any](individualResultsChannels []<-chan InputResultType, outputChannel chan<- OutputResultType, - reducerFn func(iv InputResultType, ov OutputResultType) OutputResultType, initialValue OutputResultType, closeWhenDone bool) (wg *sync.WaitGroup) { - wg = &sync.WaitGroup{} - wg.Add(len(individualResultsChannels)) - reduceLock := sync.Mutex{} - reducingFn := func(c <-chan InputResultType) { - for iv := range c { - reduceLock.Lock() - initialValue = reducerFn(iv, initialValue) - reduceLock.Unlock() - } - wg.Done() - } - for _, irc := range individualResultsChannels { - go reducingFn(irc) - } - go func() { - wg.Wait() - outputChannel <- initialValue - if closeWhenDone { - close(outputChannel) - } - }() - return wg -} - -// This function is primarily designed to be used in combination with the above utility functions. -// A slice of input result channels of a specific type is provided, along with a function to map those values to another type -// A slice of output result channels is returned, where each value is mapped as it comes in. -// The order of the slice will be retained. 
-func SliceOfChannelsTransformer[InputResultType any, OutputResultType any](inputChanels []<-chan InputResultType, mappingFn func(v InputResultType) OutputResultType) (outputChannels []<-chan OutputResultType) { - rawOutputChannels := make([]<-chan OutputResultType, len(inputChanels)) - - transformingFn := func(ic <-chan InputResultType, oc chan OutputResultType) { - for iv := range ic { - oc <- mappingFn(iv) - } - close(oc) - } - - for ci, c := range inputChanels { - roc := make(chan OutputResultType) - go transformingFn(c, roc) - rawOutputChannels[ci] = roc - } - - outputChannels = rawOutputChannels - return -} diff --git a/pkg/concurrency/concurrency_test.go b/pkg/concurrency/concurrency_test.go deleted file mode 100644 index fedd74be..00000000 --- a/pkg/concurrency/concurrency_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package concurrency_test - -// TODO: noramlly, these go in utils_tests, right? Why does this cause problems only in pkg/utils? - -import ( - "fmt" - "slices" - - . "github.com/go-skynet/LocalAI/pkg/concurrency" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("utils/concurrency tests", func() { - It("SliceOfChannelsReducer works", func() { - individualResultsChannels := []<-chan int{} - initialValue := 0 - for i := 0; i < 3; i++ { - c := make(chan int) - go func(i int, c chan int) { - for ii := 1; ii < 4; ii++ { - c <- (i * ii) - } - close(c) - }(i, c) - individualResultsChannels = append(individualResultsChannels, c) - } - Expect(len(individualResultsChannels)).To(Equal(3)) - finalResultChannel := make(chan int) - wg := SliceOfChannelsReducer[int, int](individualResultsChannels, finalResultChannel, func(input int, val int) int { - return val + input - }, initialValue, true) - - Expect(wg).ToNot(BeNil()) - - result := <-finalResultChannel - - Expect(result).ToNot(Equal(0)) - Expect(result).To(Equal(18)) - }) - - It("SliceOfChannelsRawMergerWithoutMapping works", func() { - individualResultsChannels := []<-chan int{} - for i := 0; i < 3; i++ { - c := make(chan int) - go func(i int, c chan int) { - for ii := 1; ii < 4; ii++ { - c <- (i * ii) - } - close(c) - }(i, c) - individualResultsChannels = append(individualResultsChannels, c) - } - Expect(len(individualResultsChannels)).To(Equal(3)) - outputChannel := make(chan int) - wg := SliceOfChannelsRawMergerWithoutMapping(individualResultsChannels, outputChannel, true) - Expect(wg).ToNot(BeNil()) - outputSlice := []int{} - for v := range outputChannel { - outputSlice = append(outputSlice, v) - } - Expect(len(outputSlice)).To(Equal(9)) - slices.Sort(outputSlice) - Expect(outputSlice[0]).To(BeZero()) - Expect(outputSlice[3]).To(Equal(1)) - Expect(outputSlice[8]).To(Equal(6)) - }) - - It("SliceOfChannelsTransformer works", func() { - individualResultsChannels := []<-chan int{} - for i := 0; i < 3; i++ { - c := make(chan int) - go func(i int, c chan int) { - for ii := 1; ii < 4; ii++ { - c <- (i * ii) - } - close(c) - }(i, c) - individualResultsChannels = append(individualResultsChannels, c) - } - Expect(len(individualResultsChannels)).To(Equal(3)) - mappingFn := func(i int) string { - return fmt.Sprintf("$%d", i) - } - - outputChannels := SliceOfChannelsTransformer(individualResultsChannels, mappingFn) - Expect(len(outputChannels)).To(Equal(3)) - rSlice := []string{} - for ii := 1; ii < 4; ii++ { - for i := 0; i < 3; i++ { - res := <-outputChannels[i] - rSlice = append(rSlice, res) - } - } - slices.Sort(rSlice) - Expect(rSlice[0]).To(Equal("$0")) - Expect(rSlice[3]).To(Equal("$1")) - 
Expect(rSlice[8]).To(Equal("$6")) - }) -}) diff --git a/pkg/concurrency/types.go b/pkg/concurrency/types.go deleted file mode 100644 index 76081ba3..00000000 --- a/pkg/concurrency/types.go +++ /dev/null @@ -1,6 +0,0 @@ -package concurrency - -type ErrorOr[T any] struct { - Value T - Error error -} diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 49a6b1bd..8fb8c39d 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,7 +41,7 @@ type Backend interface { PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) - AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index c0b4bc34..0af5d94f 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) { - return schema.TranscriptionResult{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { + return schema.Result{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 0e0e56c7..882db12a 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &schema.TranscriptionResult{} + tresult := &schema.Result{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index b4ba4884..73b185a3 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc. 
return e.s.TTS(ctx, in) } -func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { +func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { r, err := e.s.AudioTranscription(ctx, in) if err != nil { return nil, err } - tr := &schema.TranscriptionResult{} + tr := &schema.Result{} for _, s := range r.Segments { var tks []int for _, t := range s.Tokens { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index aa7a3fbc..4d06544d 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -15,7 +15,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) + AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 617d8f62..5d9808a4 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -81,7 +81,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if _, err := os.Stat(uri); err == nil { serverAddress, err := getFreeAddress() if err != nil { - return "", fmt.Errorf("%s failed allocating free ports: %s", backend, err.Error()) + return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) } // Make sure the process is executable if err := ml.startProcess(uri, o.model, serverAddress); err != nil { @@ -134,7 +134,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string if !ready { log.Debug().Msgf("GRPC Service NOT ready") - return "", fmt.Errorf("%s grpc service not ready", backend) + return "", fmt.Errorf("grpc service not ready") } options := *o.gRPCOptions @@ -145,10 +145,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options) if err != nil { - return "", fmt.Errorf("\"%s\" could not load model: %w", backend, err) + return "", fmt.Errorf("could not load model: %w", err) } if !res.Success { - return "", fmt.Errorf("\"%s\" could not load model (no success): %s", backend, res.Message) + return "", fmt.Errorf("could not load model (no success): %s", res.Message) } return client, nil diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go new file mode 100644 index 00000000..b09516a7 --- /dev/null +++ b/pkg/startup/model_preload.go @@ -0,0 +1,85 @@ +package startup + +import ( + "errors" + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/embedded" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +// PreloadModelsConfigurations will preload models from the given list of URLs +// It will download the model if it is not already present in the model path +// It will also try to resolve if the model is an embedded model YAML configuration +func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { + for _, url := range models { + + // As a best effort, try to resolve the model from the remote library + // if it's not resolved we try with the other method below + if modelLibraryURL != "" { + lib, err := 
embedded.GetRemoteLibraryShorteners(modelLibraryURL) + if err == nil { + if lib[url] != "" { + log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) + url = lib[url] + } + } + } + + url = embedded.ModelShortURL(url) + switch { + case embedded.ExistsInModelsLibrary(url): + modelYAML, err := embedded.ResolveContent(url) + // If we resolve something, just save it to disk and continue + if err != nil { + log.Error().Err(err).Msg("error resolving model content") + continue + } + + log.Debug().Msgf("[startup] resolved embedded model: %s", url) + md5Name := utils.MD5(url) + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") + } + case downloader.LooksLikeURL(url): + log.Debug().Msgf("[startup] resolved model to download: %s", url) + + // md5 of model name + md5Name := utils.MD5(url) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + }) + if err != nil { + log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + } + } + default: + if _, err := os.Stat(url); err == nil { + log.Debug().Msgf("[startup] resolved local model: %s", url) + // copy to modelPath + md5Name := utils.MD5(url) + + modelYAML, err := os.ReadFile(url) + if err != nil { + log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") + continue + } + + modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" + if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") + } + } else { + log.Warn().Msgf("[startup] failed resolving model '%s'", url) + } + } + } +} diff --git a/core/services/model_preload_test.go b/pkg/startup/model_preload_test.go similarity index 96% rename from core/services/model_preload_test.go rename to pkg/startup/model_preload_test.go index fc65d565..63a8f8b0 100644 --- a/core/services/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -1,14 +1,13 @@ -package services_test +package startup_test import ( "fmt" "os" "path/filepath" + . "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" - . "github.com/go-skynet/LocalAI/core/services" - . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go deleted file mode 100644 index 769d8a88..00000000 --- a/pkg/utils/base64.go +++ /dev/null @@ -1,50 +0,0 @@ -package utils - -import ( - "encoding/base64" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -var base64DownloadClient http.Client = http.Client{ - Timeout: 30 * time.Second, -} - -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string - -// This may look weird down in pkg/utils while it is currently only used in core/config -// -// but I believe it may be useful for MQTT as well in the near future, so I'm -// extracting it while I'm thinking of it. -func GetImageURLAsBase64(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := base64DownloadClient.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil - } - return "", fmt.Errorf("not valid string") -}