refactor: Minor improvements to BackendConfigLoader (#2353)

some minor renames and refactorings within BackendConfigLoader - make things more consistent, remove underused code, rename things for clarity

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-05-23 16:48:12 -04:00 committed by GitHub
parent 114f549f5e
commit 0b637465d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 116 deletions

View File

@ -360,8 +360,14 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
}
func (c *BackendConfig) Validate() bool {
downloadedFileNames := []string{}
for _, f := range c.DownloadFiles {
downloadedFileNames = append(downloadedFileNames, f.Filename)
}
validationTargets := []string{c.Backend, c.Model, c.MMProj}
validationTargets = append(validationTargets, downloadedFileNames...)
// Simple validation to make sure the model can be correctly loaded
for _, n := range []string{c.Backend, c.Model, c.MMProj} {
for _, n := range validationTargets {
if n == "" {
continue
}

View File

@ -23,6 +23,12 @@ type BackendConfigLoader struct {
sync.Mutex
}
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
type LoadOptions struct {
debug bool
threads, ctxSize int
@ -61,46 +67,8 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
}
}
// Load a config file for a model
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}
cfg.SetDefaults(opts...)
return cfg, nil
}
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig
func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
@ -117,7 +85,7 @@ func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendC
return *c, nil
}
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
@ -134,32 +102,67 @@ func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig,
return c, nil
}
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadBackendConfigFile(file, opts...)
// Load a config file for a model
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := bcl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := bcl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = bcl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}
cfg.SetDefaults(opts...)
return cfg, nil
}
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readMultipleBackendConfigsFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
if cc.Validate() {
cm.configs[cc.Name] = *cc
bcl.configs[cc.Name] = *cc
}
}
return nil
}
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file, opts...)
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
if c.Validate() {
cl.configs[c.Name] = *c
bcl.configs[c.Name] = *c
} else {
return fmt.Errorf("config is not valid")
}
@ -167,18 +170,18 @@ func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoad
return nil
}
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cl.Lock()
defer cl.Unlock()
v, exists := cl.configs[m]
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
bcl.Lock()
defer bcl.Unlock()
v, exists := bcl.configs[m]
return v, exists
}
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
for _, v := range cl.configs {
for _, v := range bcl.configs {
res = append(res, v)
}
@ -189,26 +192,16 @@ func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
return res
}
func (cl *BackendConfigLoader) RemoveBackendConfig(m string) {
cl.Lock()
defer cl.Unlock()
delete(cl.configs, m)
}
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
cl.Lock()
defer cl.Unlock()
var res []string
for k := range cl.configs {
res = append(res, k)
}
return res
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
bcl.Lock()
defer bcl.Unlock()
delete(bcl.configs, m)
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cl *BackendConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
bcl.Lock()
defer bcl.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
@ -230,7 +223,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}
for i, config := range cl.configs {
for i, config := range bcl.configs {
// Download files and verify their SHA
for i, file := range config.DownloadFiles {
@ -259,10 +252,10 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}
cc := cl.configs[i]
cc := bcl.configs[i]
c := &cc
c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c
bcl.configs[i] = *c
}
if config.IsMMProjURL() {
@ -276,22 +269,22 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}
cc := cl.configs[i]
cc := bcl.configs[i]
c := &cc
c.MMProj = modelFileName
cl.configs[i] = *c
bcl.configs[i] = *c
}
if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
if bcl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name))
}
if cl.configs[i].Description != "" {
if bcl.configs[i].Description != "" {
//glamText("**Description**")
glamText(cl.configs[i].Description)
glamText(bcl.configs[i].Description)
}
if cl.configs[i].Usage != "" {
if bcl.configs[i].Usage != "" {
//glamText("**Usage**")
glamText(cl.configs[i].Usage)
glamText(bcl.configs[i].Usage)
}
}
return nil
@ -299,9 +292,9 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return fmt.Errorf("cannot read directory '%s': %w", path, err)
@ -320,13 +313,13 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C
strings.HasPrefix(file.Name(), ".") {
continue
}
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err != nil {
log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
continue
}
if c.Validate() {
cm.configs[c.Name] = *c
bcl.configs[c.Name] = *c
} else {
log.Error().Err(err).Msgf("config is not valid")
}

View File

@ -1,12 +1,10 @@
package config_test
package config
import (
"io"
"net/http"
"os"
. "github.com/go-skynet/LocalAI/core/config"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -22,7 +20,7 @@ var _ = Describe("Test cases for config related functions", func() {
parameters:
model: "foo-bar"`)
Expect(err).ToNot(HaveOccurred())
config, err := ReadBackendConfig(tmp.Name())
config, err := readBackendConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
Expect(config.Validate()).To(BeFalse())
@ -37,7 +35,7 @@ backend: "foo-bar"
parameters:
model: "foo-bar"`)
Expect(err).ToNot(HaveOccurred())
config, err := ReadBackendConfig(tmp.Name())
config, err := readBackendConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
@ -54,7 +52,7 @@ parameters:
defer os.Remove(tmp.Name())
_, err = io.Copy(tmp, resp.Body)
Expect(err).To(BeNil())
config, err = ReadBackendConfig(tmp.Name())
config, err = readBackendConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml

View File

@ -1,10 +1,8 @@
package config_test
package config
import (
"os"
. "github.com/go-skynet/LocalAI/core/config"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -17,8 +15,8 @@ var _ = Describe("Test cases for config related functions", func() {
Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() {
config, err := ReadBackendConfigFile(configFile)
It("Test readConfigFile", func() {
config, err := readMultipleBackendConfigsFromFile(configFile)
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
@ -27,21 +25,28 @@ var _ = Describe("Test cases for config related functions", func() {
})
It("Test LoadConfigs", func() {
cm := NewBackendConfigLoader()
err := cm.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
Expect(err).To(BeNil())
Expect(cm.ListBackendConfigs()).ToNot(BeNil())
Expect(cm.ListBackendConfigs()).To(ContainElements("code-search-ada-code-001"))
bcl := NewBackendConfigLoader()
err := bcl.LoadBackendConfigsFromPath(os.Getenv("MODELS_PATH"))
Expect(err).To(BeNil())
configs := bcl.GetAllBackendConfigs()
loadedModelNames := []string{}
for _, v := range configs {
loadedModelNames = append(loadedModelNames, v.Name)
}
Expect(configs).ToNot(BeNil())
Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001"))
// config should includes text-embedding-ada-002 models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002"))
// config should includes rwkv_test models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
Expect(loadedModelNames).To(ContainElements("rwkv_test"))
// config should includes whisper-1 models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
Expect(loadedModelNames).To(ContainElements("whisper-1"))
})
})
})

View File

@ -72,7 +72,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
}
if options.ConfigFile != "" {
if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil {
if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config file")
}
}
@ -94,9 +94,8 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
}
if options.Debug {
for _, v := range cl.ListBackendConfigs() {
cfg, _ := cl.GetBackendConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
for _, v := range cl.GetAllBackendConfigs() {
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
}
}