LocalAI/core/services/gallery.go
Dave ab7b4d5ee9
[Refactor]: Core/API Split (#1506)
Refactors api folder to core, creates firm split between backend code and api frontend.
2024-01-05 15:34:56 +01:00

161 lines
4.5 KiB
Go

package services
import (
"context"
"encoding/json"
"os"
"strings"
"sync"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2"
)
// GalleryApplier receives gallery operations on C and keeps a status record
// for each one, keyed by a string identifier.
type GalleryApplier struct {
// modelPath is handed to the install and config-reload calls made while processing an operation.
modelPath string
// Embedded mutex; the status methods lock it around every access to statuses.
sync.Mutex
// C delivers operations to the goroutine launched by Start. NewGalleryApplier creates it unbuffered.
C chan gallery.GalleryOp
// statuses holds the most recently stored status for each key.
statuses map[string]*gallery.GalleryOpStatus
}
// NewGalleryApplier builds a GalleryApplier for the given model path, with an
// unbuffered operation channel and an empty status table. The returned value
// does nothing until Start is called.
func NewGalleryApplier(modelPath string) *GalleryApplier {
	applier := new(GalleryApplier)
	applier.modelPath = modelPath
	applier.C = make(chan gallery.GalleryOp)
	applier.statuses = make(map[string]*gallery.GalleryOpStatus)
	return applier
}
// UpdateStatus stores op as the current status for key s, replacing any
// status previously stored under that key. The write happens while holding
// the applier's mutex.
func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
	g.Mutex.Lock()
	defer g.Mutex.Unlock()

	g.statuses[s] = op
}
// GetStatus returns the status stored under key s, or nil when nothing has
// been stored for that key. The lookup happens while holding the applier's
// mutex.
func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus {
	g.Mutex.Lock()
	defer g.Mutex.Unlock()

	status, found := g.statuses[s]
	if !found {
		return nil
	}
	return status
}
// GetAllStatus returns a snapshot of every stored status, keyed the same way
// as UpdateStatus and GetStatus.
//
// The returned map is a copy made while holding the mutex. Returning the
// internal map directly would let callers iterate it after the lock is
// released while the goroutine started by Start keeps writing to it through
// UpdateStatus, which is a data race on the map. The copy is shallow: the
// *gallery.GalleryOpStatus values are shared with the applier. Within this
// file every update stores a freshly allocated status instead of mutating an
// existing one.
func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus {
	g.Lock()
	defer g.Unlock()

	snapshot := make(map[string]*gallery.GalleryOpStatus, len(g.statuses))
	for id, status := range g.statuses {
		snapshot[id] = status
	}
	return snapshot
}
// Start launches a goroutine that receives operations from g.C and processes
// them one at a time until the context c is done.
//
// For each operation the status stored under op.Id moves from "processing"
// (refreshed on every download progress callback) to either "completed" or
// an "error: ..." message. After a successful install the configurations are
// reloaded from the model path through cm.
func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) {
	// handle runs a single operation from start to finish, recording its outcome.
	handle := func(op gallery.GalleryOp) {
		utils.ResetDownloadTimers()
		g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})

		// fail marks the operation as finished with the given error.
		fail := func(e error) {
			g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
		}

		// onProgress publishes download progress to the status table and then
		// forwards it to the shared display function.
		onProgress := func(fileName string, current string, total string, percentage float64) {
			g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
			utils.DisplayDownloadFunction(fileName, current, total, percentage)
		}

		var err error
		switch {
		case op.GalleryName == "":
			// No gallery name: install straight from the request itself.
			err = PrepareModel(g.modelPath, op.Req, cm, onProgress)
		case strings.Contains(op.GalleryName, "@"):
			// A name containing "@" goes through the plain gallery lookup.
			err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, onProgress)
		default:
			err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, onProgress)
		}

		// Reload the model configurations only when the install succeeded.
		if err == nil {
			err = cm.LoadConfigs(g.modelPath)
		}
		if err != nil {
			fail(err)
			return
		}

		g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
	}

	go func() {
		for {
			select {
			case <-c.Done():
				return
			case op := <-g.C:
				handle(op)
			}
		}
	}()
}
// galleryModel is one entry of a request list as decoded by
// ApplyGalleryFromFile (YAML) and ApplyGalleryFromString (JSON).
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
// ID, when non-empty, makes processRequests install the entry from the supplied galleries;
// when empty, the embedded model is installed directly via PrepareModel.
ID string `json:"id"`
}
// PrepareModel fetches the installable model configuration referenced by
// req.URL, appends req.AdditionalFiles to its file list, and installs it
// into modelPath under req.Name with req.Overrides applied.
//
// downloadStatus is passed through to the installer as the progress
// callback. cm is accepted but not used by this function. Any error from
// fetching the configuration or from the install is returned unchanged.
func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error {
	installable, fetchErr := gallery.GetInstallableModelFromURL(req.URL)
	if fetchErr != nil {
		return fetchErr
	}

	installable.Files = append(installable.Files, req.AdditionalFiles...)

	return gallery.InstallModel(modelPath, req.Name, &installable, req.Overrides, downloadStatus)
}
// processRequests installs every entry of requests into modelPath, reporting
// download progress through utils.DisplayDownloadFunction.
//
// An entry with an empty ID is installed directly via PrepareModel. An entry
// with an ID is resolved against galleries: IDs containing "@" go through
// gallery.InstallModelFromGallery, all others through
// gallery.InstallModelFromGalleryByName.
//
// Every request is attempted even when an earlier one fails. The first error
// encountered is returned; previously each iteration overwrote the error, so
// a failure followed by a successful install was reported as success. The s
// parameter is unused and kept for signature compatibility.
func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
	var firstErr error
	for _, r := range requests {
		utils.ResetDownloadTimers()

		var err error
		switch {
		case r.ID == "":
			err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
		case strings.Contains(r.ID, "@"):
			err = gallery.InstallModelFromGallery(
				galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
		default:
			err = gallery.InstallModelFromGalleryByName(
				galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
		}

		// Remember the earliest failure but keep installing the remaining entries.
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// ApplyGalleryFromFile reads the file at path s, decodes its contents as a
// YAML list of gallery requests, and hands them to processRequests for
// installation into modelPath.
//
// It returns the read error, the YAML decoding error, or whatever
// processRequests returns.
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
	raw, readErr := os.ReadFile(s)
	if readErr != nil {
		return readErr
	}

	var requests []galleryModel
	parseErr := yaml.Unmarshal(raw, &requests)
	if parseErr != nil {
		return parseErr
	}

	return processRequests(modelPath, s, cm, galleries, requests)
}
// ApplyGalleryFromString decodes s as a JSON list of gallery requests and
// hands them to processRequests for installation into modelPath.
//
// It returns the JSON decoding error, or whatever processRequests returns.
func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
	var requests []galleryModel
	if decodeErr := json.Unmarshal([]byte(s), &requests); decodeErr != nil {
		return decodeErr
	}

	return processRequests(modelPath, s, cm, galleries, requests)
}