diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index 0310347e..53df785b 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. model.WithContext(o.Context), } + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { diff --git a/api/backend/image.go b/api/backend/image.go index a631b3b4..9e32d1db 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat return nil, fmt.Errorf("endpoint only working with stablediffusion models") } - inferenceModel, err := loader.BackendLoader( + opts := []model.Option{ model.WithBackendString(c.Backend), model.WithAssetDir(o.AssetsDestination), model.WithThreads(uint32(c.Threads)), model.WithContext(o.Context), model.WithModelFile(c.ImageGenerationAssets), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + inferenceModel, err := loader.BackendLoader( + opts..., ) if err != nil { return nil, err diff --git a/api/backend/llm.go b/api/backend/llm.go index 8fcd6daf..593eea3c 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -1,14 +1,17 @@ package backend import ( + "os" "regexp" "strings" "sync" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" ) func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { @@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt model.WithContext(o.Context), } + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) + } + + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { + utils.ResetDownloadTimers() + // if we failed to load the model, we try to download it + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + if err != nil { + return nil, err + } + } + } + if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { - opts = append(opts, model.WithBackendString(c.Backend)) inferenceModel, err = loader.BackendLoader(opts...) } + if err != nil { return nil, err } diff --git a/api/backend/transcript.go b/api/backend/transcript.go new file mode 100644 index 00000000..b2f25012 --- /dev/null +++ b/api/backend/transcript.go @@ -0,0 +1,42 @@ +package backend + +import ( + "context" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { + opts := []model.Option{ + model.WithBackendString(model.WhisperBackend), + model.WithModelFile(c.Model), + model.WithContext(o.Context), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + whisperModel, err := o.Loader.BackendLoader(opts...) + if err != nil { + return nil, err + } + + if whisperModel == nil { + return nil, fmt.Errorf("could not load whisper model") + } + + return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + Dst: audio, + Language: language, + Threads: uint32(c.Threads), + }) +} diff --git a/api/backend/tts.go b/api/backend/tts.go new file mode 100644 index 00000000..ac491e25 --- /dev/null +++ b/api/backend/tts.go @@ -0,0 +1,72 @@ +package backend + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" +) + +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext + + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } +} + +func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { + opts := []model.Option{ + model.WithBackendString(model.PiperBackend), + model.WithModelFile(modelFile), + model.WithContext(o.Context), + model.WithAssetDir(o.AssetsDestination), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + piperModel, err := o.Loader.BackendLoader(opts...) + if err != nil { + return "", nil, err + } + + if piperModel == nil { + return "", nil, fmt.Errorf("could not load piper model") + } + + if err := os.MkdirAll(o.AudioDir, 0755); err != nil { + return "", nil, fmt.Errorf("failed creating audio directory: %s", err) + } + + fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") + filePath := filepath.Join(o.AudioDir, fileName) + + modelPath := filepath.Join(o.Loader.ModelPath, modelFile) + + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { + return "", nil, err + } + + res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ + Text: text, + Model: modelPath, + Dst: filePath, + }) + + return filePath, res, err +} diff --git a/api/localai/gallery.go b/api/localai/gallery.go index feae2942..ef4be145 100644 --- a/api/localai/gallery.go +++ b/api/localai/gallery.go @@ -4,13 +4,15 @@ import ( "context" "fmt" "os" + "strings" "sync" - "time" json "github.com/json-iterator/go" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/rs/zerolog/log" @@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { case <-c.Done(): return case op := <-g.C: + utils.ResetDownloadTimers() + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) // updates the status with an error @@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { // displayDownload displays the download progress progressCallback := func(fileName string, current string, total string, percentage float64) { g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - displayDownload(fileName, current, total, percentage) + utils.DisplayDownloadFunction(fileName, current, total, percentage) } var err error // if the request contains a gallery name, we apply the gallery from the gallery list if op.galleryName != "" { - err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + if strings.Contains(op.galleryName, "@") { + err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } } else { err = prepareModel(g.modelPath, op.req, cm, progressCallback) } @@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { }() } -var lastProgress time.Time = time.Now() -var startTime time.Time = time.Now() - -func displayDownload(fileName string, current string, total string, percentage float64) { - currentTime := time.Now() - - if currentTime.Sub(lastProgress) >= 5*time.Second { - - lastProgress = currentTime - - // calculate ETA based on percentage and elapsed time - var eta time.Duration - if percentage > 0 { - elapsed := currentTime.Sub(startTime) - eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) - } - - if total != "" { - log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) - } else { - log.Debug().Msgf("Downloading: %s", current) - } - } -} - type galleryModel struct { gallery.GalleryModel ID string `json:"id"` @@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler } for _, r := range requests { + utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload) + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) } else { - err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload) + err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) } } diff --git a/api/localai/localai.go b/api/localai/localai.go index 7c57c92b..49f77805 100644 --- a/api/localai/localai.go +++ b/api/localai/localai.go @@ -1,17 +1,10 @@ package localai import ( - "context" - "fmt" - "os" - "path/filepath" - + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" ) @@ -20,22 +13,6 @@ type TTSRequest struct { Input string `json:"input" yaml:"input"` } -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} - func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return err } - piperModel, err := o.Loader.BackendLoader( - model.WithBackendString(model.PiperBackend), - model.WithModelFile(input.Model), - model.WithContext(o.Context), - model.WithAssetDir(o.AssetsDestination)) + filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o) if err != nil { return err } - - if piperModel == nil { - return fmt.Errorf("could not load piper model") - } - - if err := os.MkdirAll(o.AudioDir, 0755); err != nil { - return fmt.Errorf("failed creating audio directory: %s", err) - } - - fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") - filePath := filepath.Join(o.AudioDir, fileName) - - modelPath := filepath.Join(o.Loader.ModelPath, input.Model) - - if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { - return err - } - - if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ - Text: input.Input, - Model: modelPath, - Dst: filePath, - }); err != nil { - return err - } - return c.Download(filePath) } } diff --git a/api/openai/transcription.go b/api/openai/transcription.go index 346693c1..4b4a65e0 100644 --- a/api/openai/transcription.go +++ b/api/openai/transcription.go @@ -1,7 +1,6 @@ package openai import ( - "context" "fmt" "io" "net/http" @@ -9,10 +8,9 @@ import ( "path" "path/filepath" + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Audio file copied to: %+v", dst) - whisperModel, err := o.Loader.BackendLoader( - model.WithBackendString(model.WhisperBackend), - model.WithModelFile(config.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(config.Threads)), - model.WithAssetDir(o.AssetsDestination)) - if err != nil { - return err - } - - if whisperModel == nil { - return fmt.Errorf("could not load whisper model") - } - - tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: dst, - Language: input.Language, - Threads: uint32(config.Threads), - }) + tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) if err != nil { return err } diff --git a/api/options/options.go b/api/options/options.go index 06029b04..b3269470 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -28,6 +28,10 @@ type Option struct { BackendAssets embed.FS AssetsDestination string + + ExternalGRPCBackends map[string]string + + AutoloadGalleries bool } type AppOption func(*Option) @@ -53,6 +57,19 @@ func WithCors(b bool) AppOption { } } +var EnableGalleriesAutoload = func(o *Option) { + o.AutoloadGalleries = true +} + +func WithExternalBackend(name string, uri string) AppOption { + return func(o *Option) { + if o.ExternalGRPCBackends == nil { + o.ExternalGRPCBackends = make(map[string]string) + } + o.ExternalGRPCBackends[name] = uri + } +} + func WithCorsAllowOrigins(b string) AppOption { return func(o *Option) { o.CORSAllowOrigins = b diff --git a/main.go b/main.go index 3f534b0a..2cb86271 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" api "github.com/go-skynet/LocalAI/api" @@ -40,6 +41,10 @@ func main() { Name: "f16", EnvVars: []string{"F16"}, }, + &cli.BoolFlag{ + Name: "autoload-galleries", + EnvVars: []string{"AUTOLOAD_GALLERIES"}, + }, &cli.BoolFlag{ Name: "debug", EnvVars: []string{"DEBUG"}, @@ -108,6 +113,11 @@ func main() { EnvVars: []string{"BACKEND_ASSETS_PATH"}, Value: "/tmp/localai/backend_data", }, + &cli.StringSliceFlag{ + Name: "external-grpc-backends", + Usage: "A list of external grpc backends", + EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"}, + }, &cli.IntFlag{ Name: "context-size", Usage: "Default context size of the model", @@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit UsageText: `local-ai [options]`, Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { - app, err := api.App( + + opts := []options.AppOption{ options.WithConfigFile(ctx.String("config-file")), options.WithJSONStringPreload(ctx.String("preload-models")), options.WithYAMLConfigPreload(ctx.String("preload-models-config")), @@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithThreads(ctx.Int("threads")), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - options.WithUploadLimitMB(ctx.Int("upload-limit"))) + options.WithUploadLimitMB(ctx.Int("upload-limit")), + } + + externalgRPC := ctx.StringSlice("external-grpc-backends") + // split ":" to get backend name and the uri + for _, v := range externalgRPC { + backend := v[:strings.IndexByte(v, ':')] + uri := v[strings.IndexByte(v, ':')+1:] + opts = append(opts, options.WithExternalBackend(backend, uri)) + } + + if ctx.Bool("autoload-galleries") { + opts = append(opts, options.EnableGalleriesAutoload) + } + + app, err := api.App(opts...) if err != nil { return err } diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 8e085929..6fe05ed9 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -18,23 +18,15 @@ type Gallery struct { // Installs a model from the gallery (galleryname@modelname) func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { - - // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. - name = strings.ReplaceAll(name, string(os.PathSeparator), "__") - - models, err := AvailableGalleryModels(galleries, basePath) - if err != nil { - return err - } - applyModel := func(model *GalleryModel) error { config, err := GetGalleryConfigFromURL(model.URL) if err != nil { return err } + installName := model.Name if req.Name != "" { - model.Name = req.Name + installName = req.Name } config.Files = append(config.Files, req.AdditionalFiles...) @@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, return err } - if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { + if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil { return err } return nil } + models, err := AvailableGalleryModels(galleries, basePath) + if err != nil { + return err + } + + model, err := FindGallery(models, name) + if err != nil { + return err + } + + return applyModel(model) +} + +func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) { + // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. + name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + for _, model := range models { if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { - return applyModel(model) + return model, nil + } + } + return nil, fmt.Errorf("no gallery found with name %q", name) +} + +// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins) +func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { + models, err := AvailableGalleryModels(galleries, basePath) + if err != nil { + return err + } + + name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + var model *GalleryModel + for _, m := range models { + if name == m.Name { + model = m } } - return fmt.Errorf("no model found with name %q", name) + if model == nil { + return fmt.Errorf("no model found with name %q", name) + } + + return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus) } // List available models diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index d3b4bb3f..32c9afc6 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -19,8 +19,6 @@ import ( process "github.com/mudler/go-processmanager" ) -const tokenizerSuffix = ".tokenizer.json" - const ( LlamaBackend = "llama" BloomzBackend = "bloomz" @@ -45,7 +43,6 @@ const ( StableDiffusionBackend = "stablediffusion" PiperBackend = "piper" LCHuggingFaceBackend = "langchain-huggingface" - //GGLLMFalconBackend = "falcon" ) var AutoLoadBackends []string = []string{ @@ -75,75 +72,116 @@ func (ml *ModelLoader) StopGRPC() { } } +func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error { + // Make sure the process is executable + if err := os.Chmod(grpcProcess, 0755); err != nil { + return err + } + + log.Debug().Msgf("Loading GRPC Process", grpcProcess) + + log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress) + + grpcControlProcess := process.New( + process.WithTemporaryStateDir(), + process.WithName(grpcProcess), + process.WithArgs("--addr", serverAddress)) + + ml.grpcProcesses[id] = grpcControlProcess + + if err := grpcControlProcess.Run(); err != nil { + return err + } + + log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir()) + // clean up process + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + grpcControlProcess.Stop() + }() + + go func() { + t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stderr") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) + } + }() + go func() { + t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stdout") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) + } + }() + + return nil +} + // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) { return func(s string) (*grpc.Client, error) { log.Debug().Msgf("Loading GRPC Model", backend, *o) - grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + var client *grpc.Client - // Check if the file exists - if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { - return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) - } - - // Make sure the process is executable - if err := os.Chmod(grpcProcess, 0755); err != nil { - return nil, err - } - - log.Debug().Msgf("Loading GRPC Process", grpcProcess) - port, err := freeport.GetFreePort() - if err != nil { - return nil, err - } - - serverAddress := fmt.Sprintf("localhost:%d", port) - - log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress) - - grpcControlProcess := process.New( - process.WithTemporaryStateDir(), - process.WithName(grpcProcess), - process.WithArgs("--addr", serverAddress)) - - ml.grpcProcesses[o.modelFile] = grpcControlProcess - - if err := grpcControlProcess.Run(); err != nil { - return nil, err - } - - // clean up process - go func() { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - grpcControlProcess.Stop() - }() - - go func() { - t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) + getFreeAddress := func() (string, error) { + port, err := freeport.GetFreePort() if err != nil { - log.Debug().Msgf("Could not tail stderr") + return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + return fmt.Sprintf("127.0.0.1:%d", port), nil + } + + // Check if the backend is provided as external + if uri, ok := o.externalBackends[backend]; ok { + log.Debug().Msgf("Loading external backend: %s", uri) + // check if uri is a file or a address + if _, err := os.Stat(uri); err == nil { + serverAddress, err := getFreeAddress() + if err != nil { + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) + } + // Make sure the process is executable + if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil { + return nil, err + } + + log.Debug().Msgf("GRPC Service Started") + + client = grpc.NewClient(serverAddress) + } else { + // address + client = grpc.NewClient(uri) } - }() - go func() { - t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + } else { + grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + // Check if the file exists + if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { + return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) + } + + serverAddress, err := getFreeAddress() if err != nil { - log.Debug().Msgf("Could not tail stdout") + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + + // Make sure the process is executable + if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil { + return nil, err } - }() - log.Debug().Msgf("GRPC Service Started") + log.Debug().Msgf("GRPC Service Started") - client := grpc.NewClient(serverAddress) + client = grpc.NewClient(serverAddress) + } // Wait for the service to start up ready := false @@ -158,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc if !ready { log.Debug().Msgf("GRPC Service NOT ready") - log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive()) - log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:")) - - log.Debug().Msgf(grpcControlProcess.ExitCode()) - return nil, fmt.Errorf("grpc service not ready") } @@ -189,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) backend := strings.ToLower(o.backendString) + + // if an external backend is provided, use it + _, externalBackendExists := o.externalBackends[backend] + if externalBackendExists { + return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o)) + } + switch backend { case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend, MPTBackend, Gpt2Backend, FalconBackend, @@ -209,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { o := NewOptions(opts...) - log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) - // Is this really needed? BackendLoader already does this ml.mu.Lock() if m := ml.checkIsLoaded(o.modelFile); m != nil { @@ -221,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { ml.mu.Unlock() var err error - for _, b := range AutoLoadBackends { - log.Debug().Msgf("[%s] Attempting to load", b) + // autoload also external backends + allBackendsToAutoLoad := []string{} + allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...) + for _, b := range o.externalBackends { + allBackendsToAutoLoad = append(allBackendsToAutoLoad, b) + } + log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", ")) - model, modelerr := ml.BackendLoader( + for _, b := range allBackendsToAutoLoad { + log.Debug().Msgf("[%s] Attempting to load", b) + options := []Option{ WithBackendString(b), WithModelFile(o.modelFile), WithLoadGRPCLLMModelOpts(o.gRPCOptions), WithThreads(o.threads), WithAssetDir(o.assetDir), - ) + } + + for k, v := range o.externalBackends { + options = append(options, WithExternalBackend(k, v)) + } + + model, modelerr := ml.BackendLoader(options...) if modelerr == nil && model != nil { log.Debug().Msgf("[%s] Loads OK", b) return model, nil diff --git a/pkg/model/options.go b/pkg/model/options.go index 298ebd40..466e9c2f 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -14,10 +14,21 @@ type Options struct { context context.Context gRPCOptions *pb.ModelOptions + + externalBackends map[string]string } type Option func(*Options) +func WithExternalBackend(name string, uri string) Option { + return func(o *Options) { + if o.externalBackends == nil { + o.externalBackends = make(map[string]string) + } + o.externalBackends[name] = uri + } +} + func WithBackendString(backend string) Option { return func(o *Options) { o.backendString = backend diff --git a/pkg/utils/logging.go b/pkg/utils/logging.go new file mode 100644 index 00000000..d69cbf8e --- /dev/null +++ b/pkg/utils/logging.go @@ -0,0 +1,37 @@ +package utils + +import ( + "time" + + "github.com/rs/zerolog/log" +) + +var lastProgress time.Time = time.Now() +var startTime time.Time = time.Now() + +func ResetDownloadTimers() { + lastProgress = time.Now() + startTime = time.Now() +} + +func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) { + currentTime := time.Now() + + if currentTime.Sub(lastProgress) >= 5*time.Second { + + lastProgress = currentTime + + // calculate ETA based on percentage and elapsed time + var eta time.Duration + if percentage > 0 { + elapsed := currentTime.Sub(startTime) + eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) + } + + if total != "" { + log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) + } else { + log.Debug().Msgf("Downloading: %s", current) + } + } +} diff --git a/tests/models_fixtures/grpc.yaml b/tests/models_fixtures/grpc.yaml new file mode 100644 index 00000000..31c406ab --- /dev/null +++ b/tests/models_fixtures/grpc.yaml @@ -0,0 +1,5 @@ +name: code-search-ada-code-001 +backend: huggingface +embeddings: true +parameters: + model: all-MiniLM-L6-v2 \ No newline at end of file