fix(gallery): preload from file should by in YAML format (#846)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-07-31 21:13:16 +02:00 committed by GitHub
parent c1fc22e746
commit d603a9cbb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 13 deletions

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
json "github.com/json-iterator/go" json "github.com/json-iterator/go"
"gopkg.in/yaml.v3"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
@ -132,12 +133,37 @@ type galleryModel struct {
ID string `json:"id"` ID string `json:"id"`
} }
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s) dat, err := os.ReadFile(s)
if err != nil { if err != nil {
return err return err
} }
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
} }
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
@ -147,16 +173,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
return err return err
} }
for _, r := range requests { return processRequests(modelPath, s, cm, galleries, requests)
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
return err
} }
/// Endpoints /// Endpoints

View File

@ -19,6 +19,8 @@ type Gallery struct {
// Installs a model from the gallery (galleryname@modelname) // Installs a model from the gallery (galleryname@modelname)
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
applyModel := func(model *GalleryModel) error { applyModel := func(model *GalleryModel) error {
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
config, err := GetGalleryConfigFromURL(model.URL) config, err := GetGalleryConfigFromURL(model.URL)
if err != nil { if err != nil {
return err return err
@ -51,7 +53,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
model, err := FindGallery(models, name) model, err := FindGallery(models, name)
if err != nil { if err != nil {
return err var err2 error
model, err2 = FindGallery(models, strings.ToLower(name))
if err2 != nil {
return err
}
} }
return applyModel(model) return applyModel(model)
@ -79,7 +85,7 @@ func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath st
name = strings.ReplaceAll(name, string(os.PathSeparator), "__") name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
var model *GalleryModel var model *GalleryModel
for _, m := range models { for _, m := range models {
if name == m.Name { if name == m.Name || name == strings.ToLower(m.Name) {
model = m model = m
} }
} }