diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index 74833e4d..256be71f 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -29,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { - res := schema.Result{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) { + res := schema.TranscriptionResult{} dir, err := os.MkdirTemp("", "whisper") if err != nil { diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index ac93be01..a9a62d24 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error { return err } -func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) } diff --git a/core/application.go b/core/application.go new file mode 100644 index 00000000..54d3dedf --- /dev/null +++ b/core/application.go @@ -0,0 +1,39 @@ +package core + +import ( + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/model" +) + +// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy +// Perhaps a proper DI system is worth it in the future, but for now keep things simple. +type Application struct { + + // Application-Level Config + ApplicationConfig *config.ApplicationConfig + // ApplicationState *ApplicationState + + // Core Low-Level Services + BackendConfigLoader *config.BackendConfigLoader + ModelLoader *model.ModelLoader + + // Backend Services + // EmbeddingsBackendService *backend.EmbeddingsBackendService + // ImageGenerationBackendService *backend.ImageGenerationBackendService + // LLMBackendService *backend.LLMBackendService + // TranscriptionBackendService *backend.TranscriptionBackendService + // TextToSpeechBackendService *backend.TextToSpeechBackendService + + // LocalAI System Services + BackendMonitorService *services.BackendMonitorService + GalleryService *services.GalleryService + ListModelsService *services.ListModelsService + LocalAIMetricsService *services.LocalAIMetricsService + // OpenAIService *services.OpenAIService +} + +// TODO [NEXT PR?]: Break up ApplicationConfig. +// Migrate over stuff that is not set via config at all - especially runtime stuff +type ApplicationState struct { +} diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 4c3859df..e620bebd 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -11,7 +11,7 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { +func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 0d7d0cbf..cb1b7c2a 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -1,23 +1,12 @@ package config import ( - "errors" - "fmt" - "io/fs" "os" - "path/filepath" - "sort" - "strings" - "sync" "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/functions" "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v3" - - "github.com/charmbracelet/glamour" ) const ( @@ -140,7 +129,7 @@ type LLMConfig struct { EnforceEager bool `yaml:"enforce_eager"` // vLLM SwapSpace int `yaml:"swap_space"` // vLLM MaxModelLen int `yaml:"max_model_len"` // vLLM - TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM + TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM MMProj string `yaml:"mmproj"` RopeScaling string `yaml:"rope_scaling"` @@ -343,303 +332,3 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.Debug = &trueV } } - -////// Config Loader //////// - -type BackendConfigLoader struct { - configs map[string]BackendConfig - sync.Mutex -} - -type LoadOptions struct { - debug bool - threads, ctxSize int - f16 bool -} - -func LoadOptionDebug(debug bool) ConfigLoaderOption { - return func(o *LoadOptions) { - o.debug = debug - } -} - -func LoadOptionThreads(threads int) ConfigLoaderOption { - return func(o *LoadOptions) { - o.threads = threads - } -} - -func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { - return func(o *LoadOptions) { - o.ctxSize = ctxSize - } -} - -func LoadOptionF16(f16 bool) ConfigLoaderOption { - return func(o *LoadOptions) { - o.f16 = f16 - } -} - -type ConfigLoaderOption func(*LoadOptions) - -func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { - for _, l := range options { - l(lo) - } -} - -// Load a config file for a model -func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - - // Load a config file if present after the model name - cfg := &BackendConfig{ - PredictionOptions: schema.PredictionOptions{ - Model: modelName, - }, - } - - cfgExisting, exists := cl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } else { - // Try loading a model config file - modelConfig := filepath.Join(modelPath, modelName+".yaml") - if _, err := os.Stat(modelConfig); err == nil { - if err := cl.LoadBackendConfig( - modelConfig, opts..., - ); err != nil { - return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = cl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } - } - } - - cfg.SetDefaults(opts...) - - return cfg, nil -} - -func NewBackendConfigLoader() *BackendConfigLoader { - return &BackendConfigLoader{ - configs: make(map[string]BackendConfig), - } -} -func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { - c := &[]*BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - for _, cc := range *c { - cc.SetDefaults(opts...) - } - - return *c, nil -} - -func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - lo := &LoadOptions{} - lo.Apply(opts...) - - c := &BackendConfig{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - c.SetDefaults(opts...) - return c, nil -} - -func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadBackendConfigFile(file, opts...) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - cm.configs[cc.Name] = *cc - } - return nil -} - -func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { - cl.Lock() - defer cl.Unlock() - c, err := ReadBackendConfig(file, opts...) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - cl.configs[c.Name] = *c - return nil -} - -func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { - cl.Lock() - defer cl.Unlock() - v, exists := cl.configs[m] - return v, exists -} - -func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { - cl.Lock() - defer cl.Unlock() - var res []BackendConfig - for _, v := range cl.configs { - res = append(res, v) - } - - sort.SliceStable(res, func(i, j int) bool { - return res[i].Name < res[j].Name - }) - - return res -} - -func (cl *BackendConfigLoader) ListBackendConfigs() []string { - cl.Lock() - defer cl.Unlock() - var res []string - for k := range cl.configs { - res = append(res, k) - } - return res -} - -// Preload prepare models if they are not local but url or huggingface repositories -func (cl *BackendConfigLoader) Preload(modelPath string) error { - cl.Lock() - defer cl.Unlock() - - status := func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - } - - log.Info().Msgf("Preloading models from %s", modelPath) - - renderMode := "dark" - if os.Getenv("COLOR") != "" { - renderMode = os.Getenv("COLOR") - } - - glamText := func(t string) { - out, err := glamour.Render(t, renderMode) - if err == nil && os.Getenv("NO_COLOR") == "" { - fmt.Println(out) - } else { - fmt.Println(t) - } - } - - for i, config := range cl.configs { - - // Download files and verify their SHA - for i, file := range config.DownloadFiles { - log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) - - if err := utils.VerifyPath(file.Filename, modelPath); err != nil { - return err - } - // Create file path - filePath := filepath.Join(modelPath, file.Filename) - - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil { - return err - } - } - - // If the model is an URL, expand it, and download the file - if config.IsModelURL() { - modelFileName := config.ModelFileName() - modelURL := downloader.ConvertURL(config.Model) - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) - if err != nil { - return err - } - } - - cc := cl.configs[i] - c := &cc - c.PredictionOptions.Model = modelFileName - cl.configs[i] = *c - } - - if config.IsMMProjURL() { - modelFileName := config.MMProjFileName() - modelURL := downloader.ConvertURL(config.MMProj) - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) - if err != nil { - return err - } - } - - cc := cl.configs[i] - c := &cc - c.MMProj = modelFileName - cl.configs[i] = *c - } - - if cl.configs[i].Name != "" { - glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) - } - if cl.configs[i].Description != "" { - //glamText("**Description**") - glamText(cl.configs[i].Description) - } - if cl.configs[i].Usage != "" { - //glamText("**Usage**") - glamText(cl.configs[i].Usage) - } - } - return nil -} - -// LoadBackendConfigsFromPath reads all the configurations of the models from a path -// (non-recursive) -func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") || - strings.HasPrefix(file.Name(), ".") { - continue - } - c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) - if err == nil { - cm.configs[c.Name] = *c - } - } - - return nil -} diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go new file mode 100644 index 00000000..83b66740 --- /dev/null +++ b/core/config/backend_config_loader.go @@ -0,0 +1,317 @@ +package config + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/charmbracelet/glamour" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" + "gopkg.in/yaml.v3" +) + +type BackendConfigLoader struct { + configs map[string]BackendConfig + sync.Mutex +} + +type LoadOptions struct { + debug bool + threads, ctxSize int + f16 bool +} + +func LoadOptionDebug(debug bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.debug = debug + } +} + +func LoadOptionThreads(threads int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.threads = threads + } +} + +func LoadOptionContextSize(ctxSize int) ConfigLoaderOption { + return func(o *LoadOptions) { + o.ctxSize = ctxSize + } +} + +func LoadOptionF16(f16 bool) ConfigLoaderOption { + return func(o *LoadOptions) { + o.f16 = f16 + } +} + +type ConfigLoaderOption func(*LoadOptions) + +func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { + for _, l := range options { + l(lo) + } +} + +// Load a config file for a model +func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + // Load a config file if present after the model name + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, + } + + cfgExisting, exists := cl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } else { + // Try loading a model config file + modelConfig := filepath.Join(modelPath, modelName+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := cl.LoadBackendConfig( + modelConfig, opts..., + ); err != nil { + return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } + } + } + + cfg.SetDefaults(opts...) + + return cfg, nil +} + +func NewBackendConfigLoader() *BackendConfigLoader { + return &BackendConfigLoader{ + configs: make(map[string]BackendConfig), + } +} +func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { + c := &[]*BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + for _, cc := range *c { + cc.SetDefaults(opts...) + } + + return *c, nil +} + +func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + lo := &LoadOptions{} + lo.Apply(opts...) + + c := &BackendConfig{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + c.SetDefaults(opts...) + return c, nil +} + +func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadBackendConfigFile(file, opts...) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm.configs[cc.Name] = *cc + } + return nil +} + +func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { + cl.Lock() + defer cl.Unlock() + c, err := ReadBackendConfig(file, opts...) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cl.configs[c.Name] = *c + return nil +} + +func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { + cl.Lock() + defer cl.Unlock() + v, exists := cl.configs[m] + return v, exists +} + +func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { + cl.Lock() + defer cl.Unlock() + var res []BackendConfig + for _, v := range cl.configs { + res = append(res, v) + } + + sort.SliceStable(res, func(i, j int) bool { + return res[i].Name < res[j].Name + }) + + return res +} + +func (cl *BackendConfigLoader) ListBackendConfigs() []string { + cl.Lock() + defer cl.Unlock() + var res []string + for k := range cl.configs { + res = append(res, k) + } + return res +} + +// Preload prepare models if they are not local but url or huggingface repositories +func (cl *BackendConfigLoader) Preload(modelPath string) error { + cl.Lock() + defer cl.Unlock() + + status := func(fileName, current, total string, percent float64) { + utils.DisplayDownloadFunction(fileName, current, total, percent) + } + + log.Info().Msgf("Preloading models from %s", modelPath) + + renderMode := "dark" + if os.Getenv("COLOR") != "" { + renderMode = os.Getenv("COLOR") + } + + glamText := func(t string) { + out, err := glamour.Render(t, renderMode) + if err == nil && os.Getenv("NO_COLOR") == "" { + fmt.Println(out) + } else { + fmt.Println(t) + } + } + + for i, config := range cl.configs { + + // Download files and verify their SHA + for i, file := range config.DownloadFiles { + log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) + + if err := utils.VerifyPath(file.Filename, modelPath); err != nil { + return err + } + // Create file path + filePath := filepath.Join(modelPath, file.Filename) + + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil { + return err + } + } + + // If the model is an URL, expand it, and download the file + if config.IsModelURL() { + modelFileName := config.ModelFileName() + modelURL := downloader.ConvertURL(config.Model) + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) + if err != nil { + return err + } + } + + cc := cl.configs[i] + c := &cc + c.PredictionOptions.Model = modelFileName + cl.configs[i] = *c + } + + if config.IsMMProjURL() { + modelFileName := config.MMProjFileName() + modelURL := downloader.ConvertURL(config.MMProj) + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) + if err != nil { + return err + } + } + + cc := cl.configs[i] + c := &cc + c.MMProj = modelFileName + cl.configs[i] = *c + } + + if cl.configs[i].Name != "" { + glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) + } + if cl.configs[i].Description != "" { + //glamText("**Description**") + glamText(cl.configs[i].Description) + } + if cl.configs[i].Usage != "" { + //glamText("**Usage**") + glamText(cl.configs[i].Usage) + } + } + return nil +} + +// LoadBackendConfigsFromPath reads all the configurations of the models from a path +// (non-recursive) +func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") || + strings.HasPrefix(file.Name(), ".") { + continue + } + c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} diff --git a/core/http/app.go b/core/http/app.go index bd740410..080535a4 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -1,9 +1,7 @@ package http import ( - "encoding/json" "errors" - "os" "strings" "github.com/go-skynet/LocalAI/pkg/utils" @@ -124,20 +122,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi return c.Next() } - // Check for api_keys.json file - fileContent, err := os.ReadFile("api_keys.json") - if err == nil { - // Parse JSON content from the file - var fileKeys []string - err := json.Unmarshal(fileContent, &fileKeys) - if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) - } - - // Add file keys to options.ApiKeys - appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) - } - if len(appConfig.ApiKeys) == 0 { return c.Next() } @@ -174,13 +158,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi app.Use(c) } - // Make sure directories exists - os.MkdirAll(appConfig.ImageDir, 0750) - os.MkdirAll(appConfig.AudioDir, 0750) - os.MkdirAll(appConfig.UploadDir, 0750) - os.MkdirAll(appConfig.ConfigsDir, 0750) - os.MkdirAll(appConfig.ModelPath, 0750) - // Load config jsons utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go index 8c7a664a..dac20388 100644 --- a/core/http/endpoints/localai/backend_monitor.go +++ b/core/http/endpoints/localai/backend_monitor.go @@ -6,7 +6,7 @@ import ( "github.com/gofiber/fiber/v2" ) -func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { +func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) @@ -23,7 +23,7 @@ func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error } } -func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { +func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body diff --git a/core/http/endpoints/openai/list.go b/core/http/endpoints/openai/list.go index 04e611a2..2caea96b 100644 --- a/core/http/endpoints/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -1,63 +1,23 @@ package openai import ( - "regexp" - - "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/schema" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/core/services" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - models, err := ml.ListModels() - if err != nil { - return err - } - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []schema.OpenAIModel{} - - var filterFn func(name string) bool + // If blank, no filter is applied. filter := c.Query("filter") - // If filter is not specified, do not filter the list by model name - if filter == "" { - filterFn = func(_ string) bool { return true } - } else { - // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn - rxp, err := regexp.Compile(filter) - if err != nil { - return err - } - filterFn = func(name string) bool { - return rxp.MatchString(name) - } - } - // By default, exclude any loose files that are already referenced by a configuration file. excludeConfigured := c.QueryBool("excludeConfigured", true) - // Start with the known configurations - for _, c := range cl.GetAllBackendConfigs() { - if excludeConfigured { - mm[c.Model] = nil - } - - if filterFn(c.Name) { - dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) - } + dataModels, err := lms.ListModels(filter, excludeConfigured) + if err != nil { + return err } - - // Then iterate through the loose files: - for _, m := range models { - // And only adds them if they shouldn't be skipped. - if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) - } - } - return c.JSON(struct { Object string `json:"object"` Data []schema.OpenAIModel `json:"data"` diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 138babbe..a5099d60 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -52,9 +52,9 @@ func RegisterLocalAIRoutes(app *fiber.App, app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) // Experimental Backend Statistics Module - backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now - app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor)) - app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor)) + backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now + app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService)) + app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService)) app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index c51ccdcb..74f20175 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -4,6 +4,7 @@ import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/http/endpoints/localai" "github.com/go-skynet/LocalAI/core/http/endpoints/openai" + "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) @@ -81,6 +82,7 @@ func RegisterOpenAIRoutes(app *fiber.App, } // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) - app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) + tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance. + app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS)) + app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS)) } diff --git a/core/schema/whisper.go b/core/schema/transcription.go similarity index 90% rename from core/schema/whisper.go rename to core/schema/transcription.go index 41413c1f..fe1799fa 100644 --- a/core/schema/whisper.go +++ b/core/schema/transcription.go @@ -10,7 +10,7 @@ type Segment struct { Tokens []int `json:"tokens"` } -type Result struct { +type TranscriptionResult struct { Segments []Segment `json:"segments"` Text string `json:"text"` } diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index 979a67a3..4e993ed9 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -15,22 +15,22 @@ import ( gopsutil "github.com/shirou/gopsutil/v3/process" ) -type BackendMonitor struct { - configLoader *config.BackendConfigLoader - modelLoader *model.ModelLoader - options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. +type BackendMonitorService struct { + backendConfigLoader *config.BackendConfigLoader + modelLoader *model.ModelLoader + options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. } -func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor { - return BackendMonitor{ - configLoader: configLoader, - modelLoader: modelLoader, - options: appConfig, +func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService { + return &BackendMonitorService{ + modelLoader: modelLoader, + backendConfigLoader: configLoader, + options: appConfig, } } -func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { - config, exists := bm.configLoader.GetBackendConfig(modelName) +func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) { + config, exists := bms.backendConfigLoader.GetBackendConfig(modelName) var backendId string if exists { backendId = config.Model @@ -46,8 +46,8 @@ func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string return backendId, nil } -func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { - config, exists := bm.configLoader.GetBackendConfig(model) +func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { + config, exists := bms.backendConfigLoader.GetBackendConfig(model) var backend string if exists { backend = config.Model @@ -60,7 +60,7 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe backend = fmt.Sprintf("%s.bin", backend) } - pid, err := bm.modelLoader.GetGRPCPID(backend) + pid, err := bms.modelLoader.GetGRPCPID(backend) if err != nil { log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid") @@ -101,12 +101,12 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe }, nil } -func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { - backendId, err := bm.getModelLoaderIDFromModelName(modelName) +func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) { + backendId, err := bms.getModelLoaderIDFromModelName(modelName) if err != nil { return nil, err } - modelAddr := bm.modelLoader.CheckIsLoaded(backendId) + modelAddr := bms.modelLoader.CheckIsLoaded(backendId) if modelAddr == "" { return nil, fmt.Errorf("backend %s is not currently loaded", backendId) } @@ -114,7 +114,7 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) if rpcErr != nil { log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bm.SampleLocalBackendProcess(backendId) + val, slbErr := bms.SampleLocalBackendProcess(backendId) if slbErr != nil { return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) } @@ -131,10 +131,10 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse return status, nil } -func (bm BackendMonitor) ShutdownModel(modelName string) error { - backendId, err := bm.getModelLoaderIDFromModelName(modelName) +func (bms BackendMonitorService) ShutdownModel(modelName string) error { + backendId, err := bms.getModelLoaderIDFromModelName(modelName) if err != nil { return err } - return bm.modelLoader.ShutdownModel(backendId) + return bms.modelLoader.ShutdownModel(backendId) } diff --git a/core/services/list_models.go b/core/services/list_models.go new file mode 100644 index 00000000..a21e6faf --- /dev/null +++ b/core/services/list_models.go @@ -0,0 +1,72 @@ +package services + +import ( + "regexp" + + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/model" +) + +type ListModelsService struct { + bcl *config.BackendConfigLoader + ml *model.ModelLoader + appConfig *config.ApplicationConfig +} + +func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService { + return &ListModelsService{ + bcl: bcl, + ml: ml, + appConfig: appConfig, + } +} + +func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) { + + models, err := lms.ml.ListModels() + if err != nil { + return nil, err + } + + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []schema.OpenAIModel{} + + var filterFn func(name string) bool + + // If filter is not specified, do not filter the list by model name + if filter == "" { + filterFn = func(_ string) bool { return true } + } else { + // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn + rxp, err := regexp.Compile(filter) + if err != nil { + return nil, err + } + filterFn = func(name string) bool { + return rxp.MatchString(name) + } + } + + // Start with the known configurations + for _, c := range lms.bcl.GetAllBackendConfigs() { + if excludeConfigured { + mm[c.Model] = nil + } + + if filterFn(c.Name) { + dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + } + } + + // Then iterate through the loose files: + for _, m := range models { + // And only adds them if they shouldn't be skipped. + if _, exists := mm[m]; !exists && filterFn(m) { + dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) + } + } + + return dataModels, nil +} diff --git a/core/startup/startup.go b/core/startup/startup.go index e5660f4c..672aee15 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/go-skynet/LocalAI/core" "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/internal" @@ -133,3 +134,33 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode log.Info().Msg("core/startup process completed!") return cl, ml, options, nil } + +// In Lieu of a proper DI framework, this function wires up the Application manually. +// This is in core/startup rather than core/state.go to keep package references clean! +func createApplication(appConfig *config.ApplicationConfig) *core.Application { + app := &core.Application{ + ApplicationConfig: appConfig, + BackendConfigLoader: config.NewBackendConfigLoader(), + ModelLoader: model.NewModelLoader(appConfig.ModelPath), + } + + var err error + + // app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + // app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + // app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + // app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + // app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + + app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath) + app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig) + // app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService) + + app.LocalAIMetricsService, err = services.NewLocalAIMetricsService() + if err != nil { + log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.") + } + + return app +} diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index bef9e186..b5745db5 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,7 +41,7 @@ type Backend interface { PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) - AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 0af5d94f..c0b4bc34 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { - return schema.Result{}, fmt.Errorf("unimplemented") +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) { + return schema.TranscriptionResult{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index fc4a12fa..06ccc1b4 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &schema.Result{} + tresult := &schema.TranscriptionResult{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 694e83b0..d2038759 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc. return e.s.TTS(ctx, in) } -func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { +func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) { r, err := e.s.AudioTranscription(ctx, in) if err != nil { return nil, err } - tr := &schema.Result{} + tr := &schema.TranscriptionResult{} for _, s := range r.Segments { var tks []int for _, t := range s.Tokens { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 4d06544d..aa7a3fbc 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -15,7 +15,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) + AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go new file mode 100644 index 00000000..977156e9 --- /dev/null +++ b/pkg/utils/base64.go @@ -0,0 +1,50 @@ +package utils + +import ( + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +var base64DownloadClient http.Client = http.Client{ + Timeout: 30 * time.Second, +} + +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string + +// This may look weird down in pkg/utils while it is currently only used in core/config +// +// but I believe it may be useful for MQTT as well in the near future, so I'm +// extracting it while I'm thinking of it. +func GetImageURLAsBase64(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := base64DownloadClient.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return "", fmt.Errorf("not valid string") +} \ No newline at end of file diff --git a/pkg/utils/base64_test.go b/pkg/utils/base64_test.go new file mode 100644 index 00000000..28a09d17 --- /dev/null +++ b/pkg/utils/base64_test.go @@ -0,0 +1,31 @@ +package utils_test + +import ( + . "github.com/go-skynet/LocalAI/pkg/utils" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("utils/base64 tests", func() { + It("GetImageURLAsBase64 can strip data url prefixes", func() { + // This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes. + input := "data:image/jpeg;base64,FOO" + b64, err := GetImageURLAsBase64(input) + Expect(err).To(BeNil()) + Expect(b64).To(Equal("FOO")) + }) + It("GetImageURLAsBase64 returns an error for bogus data", func() { + input := "FOO" + b64, err := GetImageURLAsBase64(input) + Expect(b64).To(Equal("")) + Expect(err).ToNot(BeNil()) + Expect(err).To(MatchError("not valid string")) + }) + It("GetImageURLAsBase64 can actually download images and calculates something", func() { + // This test doesn't actually _check_ the results at this time, which is bad, but there wasn't a test at all before... + input := "https://upload.wikimedia.org/wikipedia/en/2/29/Wargames.jpg" + b64, err := GetImageURLAsBase64(input) + Expect(err).To(BeNil()) + Expect(b64).ToNot(BeNil()) + }) +})