diff --git a/core/config/backend_config.go b/core/config/backend_config.go index f0906de3..a4979233 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -360,8 +360,14 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { } func (c *BackendConfig) Validate() bool { + downloadedFileNames := []string{} + for _, f := range c.DownloadFiles { + downloadedFileNames = append(downloadedFileNames, f.Filename) + } + validationTargets := []string{c.Backend, c.Model, c.MMProj} + validationTargets = append(validationTargets, downloadedFileNames...) // Simple validation to make sure the model can be correctly loaded - for _, n := range []string{c.Backend, c.Model, c.MMProj} { + for _, n := range validationTargets { if n == "" { continue } diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go index ce02fc9a..54e33cf8 100644 --- a/core/config/backend_config_loader.go +++ b/core/config/backend_config_loader.go @@ -23,6 +23,12 @@ type BackendConfigLoader struct { sync.Mutex } +func NewBackendConfigLoader() *BackendConfigLoader { + return &BackendConfigLoader{ + configs: make(map[string]BackendConfig), + } +} + type LoadOptions struct { debug bool threads, ctxSize int @@ -61,46 +67,8 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) { } } -// Load a config file for a model -func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { - - // Load a config file if present after the model name - cfg := &BackendConfig{ - PredictionOptions: schema.PredictionOptions{ - Model: modelName, - }, - } - - cfgExisting, exists := cl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } else { - // Try loading a model config file - modelConfig := filepath.Join(modelPath, modelName+".yaml") - if _, err := os.Stat(modelConfig); err == nil { - if err := cl.LoadBackendConfig( - modelConfig, opts..., - ); err != nil { - return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = cl.GetBackendConfig(modelName) - if exists { - cfg = &cfgExisting - } - } - } - - cfg.SetDefaults(opts...) - - return cfg, nil -} - -func NewBackendConfigLoader() *BackendConfigLoader { - return &BackendConfigLoader{ - configs: make(map[string]BackendConfig), - } -} -func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { +// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig +func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) { c := &[]*BackendConfig{} f, err := os.ReadFile(file) if err != nil { @@ -117,7 +85,7 @@ func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendC return *c, nil } -func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { +func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) { lo := &LoadOptions{} lo.Apply(opts...) @@ -134,32 +102,67 @@ func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, return c, nil } -func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadBackendConfigFile(file, opts...) +// Load a config file for a model +func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) { + + // Load a config file if present after the model name + cfg := &BackendConfig{ + PredictionOptions: schema.PredictionOptions{ + Model: modelName, + }, + } + + cfgExisting, exists := bcl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } else { + // Try loading a model config file + modelConfig := filepath.Join(modelPath, modelName+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := bcl.LoadBackendConfig( + modelConfig, opts..., + ); err != nil { + return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = bcl.GetBackendConfig(modelName) + if exists { + cfg = &cfgExisting + } + } + } + + cfg.SetDefaults(opts...) + + return cfg, nil +} + +// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile +func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error { + bcl.Lock() + defer bcl.Unlock() + c, err := readMultipleBackendConfigsFromFile(file, opts...) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } for _, cc := range c { if cc.Validate() { - cm.configs[cc.Name] = *cc + bcl.configs[cc.Name] = *cc } } return nil } -func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { - cl.Lock() - defer cl.Unlock() - c, err := ReadBackendConfig(file, opts...) +func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error { + bcl.Lock() + defer bcl.Unlock() + c, err := readBackendConfigFromFile(file, opts...) if err != nil { return fmt.Errorf("cannot read config file: %w", err) } if c.Validate() { - cl.configs[c.Name] = *c + bcl.configs[c.Name] = *c } else { return fmt.Errorf("config is not valid") } @@ -167,18 +170,18 @@ func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoad return nil } -func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { - cl.Lock() - defer cl.Unlock() - v, exists := cl.configs[m] +func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) { + bcl.Lock() + defer bcl.Unlock() + v, exists := bcl.configs[m] return v, exists } -func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { - cl.Lock() - defer cl.Unlock() +func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { + bcl.Lock() + defer bcl.Unlock() var res []BackendConfig - for _, v := range cl.configs { + for _, v := range bcl.configs { res = append(res, v) } @@ -189,26 +192,16 @@ func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig { return res } -func (cl *BackendConfigLoader) RemoveBackendConfig(m string) { - cl.Lock() - defer cl.Unlock() - delete(cl.configs, m) -} - -func (cl *BackendConfigLoader) ListBackendConfigs() []string { - cl.Lock() - defer cl.Unlock() - var res []string - for k := range cl.configs { - res = append(res, k) - } - return res +func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) { + bcl.Lock() + defer bcl.Unlock() + delete(bcl.configs, m) } // Preload prepare models if they are not local but url or huggingface repositories -func (cl *BackendConfigLoader) Preload(modelPath string) error { - cl.Lock() - defer cl.Unlock() +func (bcl *BackendConfigLoader) Preload(modelPath string) error { + bcl.Lock() + defer bcl.Unlock() status := func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) @@ -230,7 +223,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { } } - for i, config := range cl.configs { + for i, config := range bcl.configs { // Download files and verify their SHA for i, file := range config.DownloadFiles { @@ -259,10 +252,10 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { } } - cc := cl.configs[i] + cc := bcl.configs[i] c := &cc c.PredictionOptions.Model = modelFileName - cl.configs[i] = *c + bcl.configs[i] = *c } if config.IsMMProjURL() { @@ -276,22 +269,22 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { } } - cc := cl.configs[i] + cc := bcl.configs[i] c := &cc c.MMProj = modelFileName - cl.configs[i] = *c + bcl.configs[i] = *c } - if cl.configs[i].Name != "" { - glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) + if bcl.configs[i].Name != "" { + glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name)) } - if cl.configs[i].Description != "" { + if bcl.configs[i].Description != "" { //glamText("**Description**") - glamText(cl.configs[i].Description) + glamText(bcl.configs[i].Description) } - if cl.configs[i].Usage != "" { + if bcl.configs[i].Usage != "" { //glamText("**Usage**") - glamText(cl.configs[i].Usage) + glamText(bcl.configs[i].Usage) } } return nil @@ -299,9 +292,9 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { // LoadBackendConfigsFromPath reads all the configurations of the models from a path // (non-recursive) -func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { - cm.Lock() - defer cm.Unlock() +func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { + bcl.Lock() + defer bcl.Unlock() entries, err := os.ReadDir(path) if err != nil { return fmt.Errorf("cannot read directory '%s': %w", path, err) @@ -320,13 +313,13 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C strings.HasPrefix(file.Name(), ".") { continue } - c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) + c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...) if err != nil { log.Error().Err(err).Msgf("cannot read config file: %s", file.Name()) continue } if c.Validate() { - cm.configs[c.Name] = *c + bcl.configs[c.Name] = *c } else { log.Error().Err(err).Msgf("config is not valid") } diff --git a/core/config/backend_config_test.go b/core/config/backend_config_test.go index 4c1437e3..48bcfa9c 100644 --- a/core/config/backend_config_test.go +++ b/core/config/backend_config_test.go @@ -1,12 +1,10 @@ -package config_test +package config import ( "io" "net/http" "os" - . "github.com/go-skynet/LocalAI/core/config" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -22,7 +20,7 @@ var _ = Describe("Test cases for config related functions", func() { parameters: model: "foo-bar"`) Expect(err).ToNot(HaveOccurred()) - config, err := ReadBackendConfig(tmp.Name()) + config, err := readBackendConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) Expect(config.Validate()).To(BeFalse()) @@ -37,7 +35,7 @@ backend: "foo-bar" parameters: model: "foo-bar"`) Expect(err).ToNot(HaveOccurred()) - config, err := ReadBackendConfig(tmp.Name()) + config, err := readBackendConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -54,7 +52,7 @@ parameters: defer os.Remove(tmp.Name()) _, err = io.Copy(tmp, resp.Body) Expect(err).To(BeNil()) - config, err = ReadBackendConfig(tmp.Name()) + config, err = readBackendConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml diff --git a/core/config/config_test.go b/core/config/config_test.go index eeb5ad89..da8ba09d 100644 --- a/core/config/config_test.go +++ b/core/config/config_test.go @@ -1,10 +1,8 @@ -package config_test +package config import ( "os" - . "github.com/go-skynet/LocalAI/core/config" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -17,8 +15,8 @@ var _ = Describe("Test cases for config related functions", func() { Context("Test Read configuration functions", func() { configFile = os.Getenv("CONFIG_FILE") - It("Test ReadConfigFile", func() { - config, err := ReadBackendConfigFile(configFile) + It("Test readConfigFile", func() { + config, err := readMultipleBackendConfigsFromFile(configFile) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -27,21 +25,28 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := NewBackendConfigLoader() - err := cm.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH")) - Expect(err).To(BeNil()) - Expect(cm.ListBackendConfigs()).ToNot(BeNil()) - Expect(cm.ListBackendConfigs()).To(ContainElements("code-search-ada-code-001")) + bcl := NewBackendConfigLoader() + err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH")) + + Expect(err).To(BeNil()) + configs := bcl.GetAllBackendConfigs() + loadedModelNames := []string{} + for _, v := range configs { + loadedModelNames = append(loadedModelNames, v.Name) + } + Expect(configs).ToNot(BeNil()) + + Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001")) // config should includes text-embedding-ada-002 models's api.config - Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002")) + Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config - Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test")) + Expect(loadedModelNames).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config - Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1")) + Expect(loadedModelNames).To(ContainElements("whisper-1")) }) }) }) diff --git a/core/startup/startup.go b/core/startup/startup.go index 2c94391c..c337afb1 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -72,7 +72,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode } if options.ConfigFile != "" { - if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil { + if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { log.Error().Err(err).Msg("error loading config file") } } @@ -94,9 +94,8 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode } if options.Debug { - for _, v := range cl.ListBackendConfigs() { - cfg, _ := cl.GetBackendConfig(v) - log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) + for _, v := range cl.GetAllBackendConfigs() { + log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v) } }