mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
refactor: Minor improvements to BackendConfigLoader (#2353)
some minor renames and refactorings within BackendConfigLoader - make things more consistent, remove underused code, rename things for clarity Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
parent
114f549f5e
commit
0b637465d9
@ -360,8 +360,14 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
|
||||
}
|
||||
|
||||
func (c *BackendConfig) Validate() bool {
|
||||
downloadedFileNames := []string{}
|
||||
for _, f := range c.DownloadFiles {
|
||||
downloadedFileNames = append(downloadedFileNames, f.Filename)
|
||||
}
|
||||
validationTargets := []string{c.Backend, c.Model, c.MMProj}
|
||||
validationTargets = append(validationTargets, downloadedFileNames...)
|
||||
// Simple validation to make sure the model can be correctly loaded
|
||||
for _, n := range []string{c.Backend, c.Model, c.MMProj} {
|
||||
for _, n := range validationTargets {
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
|
@ -23,6 +23,12 @@ type BackendConfigLoader struct {
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewBackendConfigLoader() *BackendConfigLoader {
|
||||
return &BackendConfigLoader{
|
||||
configs: make(map[string]BackendConfig),
|
||||
}
|
||||
}
|
||||
|
||||
type LoadOptions struct {
|
||||
debug bool
|
||||
threads, ctxSize int
|
||||
@ -61,46 +67,8 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
|
||||
}
|
||||
}
|
||||
|
||||
// Load a config file for a model
|
||||
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||
|
||||
// Load a config file if present after the model name
|
||||
cfg := &BackendConfig{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
Model: modelName,
|
||||
},
|
||||
}
|
||||
|
||||
cfgExisting, exists := cl.GetBackendConfig(modelName)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
} else {
|
||||
// Try loading a model config file
|
||||
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := cl.LoadBackendConfig(
|
||||
modelConfig, opts...,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfgExisting, exists = cl.GetBackendConfig(modelName)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg.SetDefaults(opts...)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func NewBackendConfigLoader() *BackendConfigLoader {
|
||||
return &BackendConfigLoader{
|
||||
configs: make(map[string]BackendConfig),
|
||||
}
|
||||
}
|
||||
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
|
||||
// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig
|
||||
func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
|
||||
c := &[]*BackendConfig{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@ -117,7 +85,7 @@ func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendC
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||
func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||
lo := &LoadOptions{}
|
||||
lo.Apply(opts...)
|
||||
|
||||
@ -134,32 +102,67 @@ func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig,
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadBackendConfigFile(file, opts...)
|
||||
// Load a config file for a model
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
|
||||
|
||||
// Load a config file if present after the model name
|
||||
cfg := &BackendConfig{
|
||||
PredictionOptions: schema.PredictionOptions{
|
||||
Model: modelName,
|
||||
},
|
||||
}
|
||||
|
||||
cfgExisting, exists := bcl.GetBackendConfig(modelName)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
} else {
|
||||
// Try loading a model config file
|
||||
modelConfig := filepath.Join(modelPath, modelName+".yaml")
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := bcl.LoadBackendConfig(
|
||||
modelConfig, opts...,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfgExisting, exists = bcl.GetBackendConfig(modelName)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg.SetDefaults(opts...)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
|
||||
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
c, err := readMultipleBackendConfigsFromFile(file, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
|
||||
for _, cc := range c {
|
||||
if cc.Validate() {
|
||||
cm.configs[cc.Name] = *cc
|
||||
bcl.configs[cc.Name] = *cc
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
c, err := ReadBackendConfig(file, opts...)
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
c, err := readBackendConfigFromFile(file, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
|
||||
if c.Validate() {
|
||||
cl.configs[c.Name] = *c
|
||||
bcl.configs[c.Name] = *c
|
||||
} else {
|
||||
return fmt.Errorf("config is not valid")
|
||||
}
|
||||
@ -167,18 +170,18 @@ func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoad
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
v, exists := cl.configs[m]
|
||||
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
v, exists := bcl.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
var res []BackendConfig
|
||||
for _, v := range cl.configs {
|
||||
for _, v := range bcl.configs {
|
||||
res = append(res, v)
|
||||
}
|
||||
|
||||
@ -189,26 +192,16 @@ func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
|
||||
return res
|
||||
}
|
||||
|
||||
func (cl *BackendConfigLoader) RemoveBackendConfig(m string) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
delete(cl.configs, m)
|
||||
}
|
||||
|
||||
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
var res []string
|
||||
for k := range cl.configs {
|
||||
res = append(res, k)
|
||||
}
|
||||
return res
|
||||
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
delete(bcl.configs, m)
|
||||
}
|
||||
|
||||
// Preload prepare models if they are not local but url or huggingface repositories
|
||||
func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
|
||||
status := func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
@ -230,7 +223,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
}
|
||||
}
|
||||
|
||||
for i, config := range cl.configs {
|
||||
for i, config := range bcl.configs {
|
||||
|
||||
// Download files and verify their SHA
|
||||
for i, file := range config.DownloadFiles {
|
||||
@ -259,10 +252,10 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
}
|
||||
}
|
||||
|
||||
cc := cl.configs[i]
|
||||
cc := bcl.configs[i]
|
||||
c := &cc
|
||||
c.PredictionOptions.Model = modelFileName
|
||||
cl.configs[i] = *c
|
||||
bcl.configs[i] = *c
|
||||
}
|
||||
|
||||
if config.IsMMProjURL() {
|
||||
@ -276,22 +269,22 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
}
|
||||
}
|
||||
|
||||
cc := cl.configs[i]
|
||||
cc := bcl.configs[i]
|
||||
c := &cc
|
||||
c.MMProj = modelFileName
|
||||
cl.configs[i] = *c
|
||||
bcl.configs[i] = *c
|
||||
}
|
||||
|
||||
if cl.configs[i].Name != "" {
|
||||
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
|
||||
if bcl.configs[i].Name != "" {
|
||||
glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name))
|
||||
}
|
||||
if cl.configs[i].Description != "" {
|
||||
if bcl.configs[i].Description != "" {
|
||||
//glamText("**Description**")
|
||||
glamText(cl.configs[i].Description)
|
||||
glamText(bcl.configs[i].Description)
|
||||
}
|
||||
if cl.configs[i].Usage != "" {
|
||||
if bcl.configs[i].Usage != "" {
|
||||
//glamText("**Usage**")
|
||||
glamText(cl.configs[i].Usage)
|
||||
glamText(bcl.configs[i].Usage)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -299,9 +292,9 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
|
||||
|
||||
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
|
||||
// (non-recursive)
|
||||
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read directory '%s': %w", path, err)
|
||||
@ -320,13 +313,13 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C
|
||||
strings.HasPrefix(file.Name(), ".") {
|
||||
continue
|
||||
}
|
||||
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
|
||||
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
|
||||
continue
|
||||
}
|
||||
if c.Validate() {
|
||||
cm.configs[c.Name] = *c
|
||||
bcl.configs[c.Name] = *c
|
||||
} else {
|
||||
log.Error().Err(err).Msgf("config is not valid")
|
||||
}
|
||||
|
@ -1,12 +1,10 @@
|
||||
package config_test
|
||||
package config
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@ -22,7 +20,7 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
parameters:
|
||||
model: "foo-bar"`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := ReadBackendConfig(tmp.Name())
|
||||
config, err := readBackendConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
Expect(config.Validate()).To(BeFalse())
|
||||
@ -37,7 +35,7 @@ backend: "foo-bar"
|
||||
parameters:
|
||||
model: "foo-bar"`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
config, err := ReadBackendConfig(tmp.Name())
|
||||
config, err := readBackendConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@ -54,7 +52,7 @@ parameters:
|
||||
defer os.Remove(tmp.Name())
|
||||
_, err = io.Copy(tmp, resp.Body)
|
||||
Expect(err).To(BeNil())
|
||||
config, err = ReadBackendConfig(tmp.Name())
|
||||
config, err = readBackendConfigFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
|
@ -1,10 +1,8 @@
|
||||
package config_test
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/core/config"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@ -17,8 +15,8 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
|
||||
Context("Test Read configuration functions", func() {
|
||||
configFile = os.Getenv("CONFIG_FILE")
|
||||
It("Test ReadConfigFile", func() {
|
||||
config, err := ReadBackendConfigFile(configFile)
|
||||
It("Test readConfigFile", func() {
|
||||
config, err := readMultipleBackendConfigsFromFile(configFile)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@ -27,21 +25,28 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
})
|
||||
|
||||
It("Test LoadConfigs", func() {
|
||||
cm := NewBackendConfigLoader()
|
||||
err := cm.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(cm.ListBackendConfigs()).ToNot(BeNil())
|
||||
|
||||
Expect(cm.ListBackendConfigs()).To(ContainElements("code-search-ada-code-001"))
|
||||
bcl := NewBackendConfigLoader()
|
||||
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
configs := bcl.GetAllBackendConfigs()
|
||||
loadedModelNames := []string{}
|
||||
for _, v := range configs {
|
||||
loadedModelNames = append(loadedModelNames, v.Name)
|
||||
}
|
||||
Expect(configs).ToNot(BeNil())
|
||||
|
||||
Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001"))
|
||||
|
||||
// config should includes text-embedding-ada-002 models's api.config
|
||||
Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
|
||||
Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002"))
|
||||
|
||||
// config should includes rwkv_test models's api.config
|
||||
Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
|
||||
Expect(loadedModelNames).To(ContainElements("rwkv_test"))
|
||||
|
||||
// config should includes whisper-1 models's api.config
|
||||
Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
|
||||
Expect(loadedModelNames).To(ContainElements("whisper-1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -72,7 +72,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
|
||||
log.Error().Err(err).Msg("error loading config file")
|
||||
}
|
||||
}
|
||||
@ -94,9 +94,8 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
for _, v := range cl.ListBackendConfigs() {
|
||||
cfg, _ := cl.GetBackendConfig(v)
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||
for _, v := range cl.GetAllBackendConfigs() {
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user