mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
fix(gallery): preload from file should by in YAML format (#846)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
c1fc22e746
commit
d603a9cbb5
@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
|
||||
json "github.com/json-iterator/go"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
@ -132,12 +133,37 @@ type galleryModel struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
if strings.Contains(r.ID, "@") {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||
dat, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries)
|
||||
var requests []galleryModel
|
||||
|
||||
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||
@ -147,16 +173,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
/// Endpoints
|
||||
|
@ -19,6 +19,8 @@ type Gallery struct {
|
||||
// Installs a model from the gallery (galleryname@modelname)
|
||||
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
config, err := GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -51,7 +53,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
|
||||
model, err := FindGallery(models, name)
|
||||
if err != nil {
|
||||
return err
|
||||
var err2 error
|
||||
model, err2 = FindGallery(models, strings.ToLower(name))
|
||||
if err2 != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return applyModel(model)
|
||||
@ -79,7 +85,7 @@ func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath st
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
var model *GalleryModel
|
||||
for _, m := range models {
|
||||
if name == m.Name {
|
||||
if name == m.Name || name == strings.ToLower(m.Name) {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user