diff --git a/api/localai/gallery.go b/api/localai/gallery.go index ef4be145..b4180ada 100644 --- a/api/localai/gallery.go +++ b/api/localai/gallery.go @@ -8,6 +8,7 @@ import ( "sync" json "github.com/json-iterator/go" + "gopkg.in/yaml.v3" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" @@ -132,12 +133,37 @@ type galleryModel struct { ID string `json:"id"` } +func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { + var err error + for _, r := range requests { + utils.ResetDownloadTimers() + if r.ID == "" { + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { + if strings.Contains(r.ID, "@") { + err = gallery.InstallModelFromGallery( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } else { + err = gallery.InstallModelFromGalleryByName( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } + } + } + return err +} + func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err } - return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) + var requests []galleryModel + + if err := yaml.Unmarshal(dat, &requests); err != nil { + return err + } + + return processRequests(modelPath, s, cm, galleries, requests) } func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { @@ -147,16 +173,7 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler return err } - for _, r := range requests { - utils.ResetDownloadTimers() - if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) - } else { - err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } - } - - return err + return processRequests(modelPath, s, cm, galleries, requests) } /// Endpoints diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 6fe05ed9..f9a14002 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -19,6 +19,8 @@ type Gallery struct { // Installs a model from the gallery (galleryname@modelname) func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { applyModel := func(model *GalleryModel) error { + name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + config, err := GetGalleryConfigFromURL(model.URL) if err != nil { return err @@ -51,7 +53,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, model, err := FindGallery(models, name) if err != nil { - return err + var err2 error + model, err2 = FindGallery(models, strings.ToLower(name)) + if err2 != nil { + return err + } } return applyModel(model) @@ -79,7 +85,7 @@ func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath st name = strings.ReplaceAll(name, string(os.PathSeparator), "__") var model *GalleryModel for _, m := range models { - if name == m.Name { + if name == m.Name || name == strings.ToLower(m.Name) { model = m } }