diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 35e0776d..0d7d0cbf 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -184,6 +184,36 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool { return len(c.functionCallNameString) > 0 } +// MMProjFileName returns the filename of the MMProj file +// If the MMProj is a URL, it will return the MD5 of the URL which is the filename +func (c *BackendConfig) MMProjFileName() string { + modelURL := downloader.ConvertURL(c.MMProj) + if downloader.LooksLikeURL(modelURL) { + return utils.MD5(modelURL) + } + + return c.MMProj +} + +func (c *BackendConfig) IsMMProjURL() bool { + return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj)) +} + +func (c *BackendConfig) IsModelURL() bool { + return downloader.LooksLikeURL(downloader.ConvertURL(c.Model)) +} + +// ModelFileName returns the filename of the model +// If the model is a URL, it will return the MD5 of the URL which is the filename +func (c *BackendConfig) ModelFileName() string { + modelURL := downloader.ConvertURL(c.Model) + if downloader.LooksLikeURL(modelURL) { + return utils.MD5(modelURL) + } + + return c.Model +} + func (c *BackendConfig) FunctionToCall() string { if c.functionCallNameString != "" && c.functionCallNameString != "none" && c.functionCallNameString != "auto" { @@ -532,16 +562,13 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { } } - modelURL := config.PredictionOptions.Model - modelURL = downloader.ConvertURL(modelURL) - - if downloader.LooksLikeURL(modelURL) { - // md5 of model name - md5Name := utils.MD5(modelURL) - + // If the model is an URL, expand it, and download the file + if config.IsModelURL() { + modelFileName := config.ModelFileName() + modelURL := downloader.ConvertURL(config.Model) // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", 0, 0, status) + if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) if err != nil { return err } @@ -549,9 +576,27 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { cc := cl.configs[i] c := &cc - c.PredictionOptions.Model = md5Name + c.PredictionOptions.Model = modelFileName cl.configs[i] = *c } + + if config.IsMMProjURL() { + modelFileName := config.MMProjFileName() + modelURL := downloader.ConvertURL(config.MMProj) + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status) + if err != nil { + return err + } + } + + cc := cl.configs[i] + c := &cc + c.MMProj = modelFileName + cl.configs[i] = *c + } + if cl.configs[i].Name != "" { glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name)) } @@ -586,7 +631,8 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C } for _, file := range files { // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") || + strings.HasPrefix(file.Name(), ".") { continue } c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...) diff --git a/core/http/elements/gallery.go b/core/http/elements/gallery.go index 6edbd23d..8093b042 100644 --- a/core/http/elements/gallery.go +++ b/core/http/elements/gallery.go @@ -13,7 +13,7 @@ const ( NoImage = "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg" ) -func DoneProgress(uid string) string { +func DoneProgress(uid, text string) string { return elem.Div( attrs.Props{}, elem.H3( @@ -23,7 +23,7 @@ func DoneProgress(uid string) string { "tabindex": "-1", "autofocus": "", }, - elem.Text("Installation completed"), + elem.Text(text), ), ).Render() } @@ -60,7 +60,7 @@ func ProgressBar(progress string) string { ).Render() } -func StartProgressBar(uid, progress string) string { +func StartProgressBar(uid, progress, text string) string { if progress == "" { progress = "0" } @@ -77,7 +77,7 @@ func StartProgressBar(uid, progress string) string { "tabindex": "-1", "autofocus": "", }, - elem.Text("Installing"), + elem.Text(text), // This is a simple example of how to use the HTMLX library to create a progress bar that updates every 600ms. elem.Div(attrs.Props{ "hx-get": "/browse/job/progress/" + uid, @@ -106,14 +106,33 @@ func cardSpan(text, icon string) elem.Node { func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[string, string]) string { //StartProgressBar(uid, "0") modelsElements := []elem.Node{} - span := func(s string) elem.Node { - return elem.Span( + // span := func(s string) elem.Node { + // return elem.Span( + // attrs.Props{ + // "class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs", + // }, + // elem.Text(s), + // ) + // } + deleteButton := func(m *gallery.GalleryModel) elem.Node { + return elem.Button( attrs.Props{ - "class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs", + "data-twe-ripple-init": "", + "data-twe-ripple-color": "light", + "class": "float-right inline-block rounded bg-red-800 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-red-accent-300 hover:shadow-red-2 focus:bg-red-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-red-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong", + "hx-swap": "outerHTML", + // post the Model ID as param + "hx-post": "/browse/delete/model/" + m.Name, }, - elem.Text(s), + elem.I( + attrs.Props{ + "class": "fa-solid fa-cancel pr-2", + }, + ), + elem.Text("Delete"), ) } + installButton := func(m *gallery.GalleryModel) elem.Node { return elem.Button( attrs.Props{ @@ -202,10 +221,14 @@ func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[stri elem.If( currentlyInstalling, elem.Node( // If currently installing, show progress bar - elem.Raw(StartProgressBar(installing.Get(galleryID), "0")), + elem.Raw(StartProgressBar(installing.Get(galleryID), "0", "Installing")), ), // Otherwise, show install button (if not installed) or display "Installed" elem.If(m.Installed, - span("Installed"), + //elem.Node(elem.Div( + // attrs.Props{}, + // span("Installed"), deleteButton(m), + // )), + deleteButton(m), installButton(m), ), ), diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go index b693e7c3..a74a2bb9 100644 --- a/core/http/endpoints/localai/gallery.go +++ b/core/http/endpoints/localai/gallery.go @@ -74,6 +74,27 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe } } +func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + modelName := c.Params("name") + + mgs.galleryApplier.C <- gallery.GalleryOp{ + Delete: true, + GalleryName: modelName, + } + + uuid, err := uuid.NewUUID() + if err != nil { + return err + } + + return c.JSON(struct { + ID string `json:"uuid"` + StatusURL string `json:"status"` + }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + } +} + func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 6415c894..138babbe 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -23,6 +23,8 @@ func RegisterLocalAIRoutes(app *fiber.App, modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) + app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint()) + app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint()) diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index b63b1870..2b8c6b95 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -66,6 +66,12 @@ func RegisterUIRoutes(app *fiber.App, return c.SendString(elements.ListModels(filteredModels, installingModels)) }) + /* + + Install routes + + */ + // This route is used when the "Install" button is pressed, we submit here a new job to the gallery service // https://htmx.org/examples/progress-bar/ app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error { @@ -89,7 +95,33 @@ func RegisterUIRoutes(app *fiber.App, galleryService.C <- op }() - return c.SendString(elements.StartProgressBar(uid, "0")) + return c.SendString(elements.StartProgressBar(uid, "0", "Installation")) + }) + + // This route is used when the "Install" button is pressed, we submit here a new job to the gallery service + // https://htmx.org/examples/progress-bar/ + app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error { + galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests! + + id, err := uuid.NewUUID() + if err != nil { + return err + } + + uid := id.String() + + installingModels.Set(galleryID, uid) + + op := gallery.GalleryOp{ + Id: uid, + Delete: true, + GalleryName: galleryID, + } + go func() { + galleryService.C <- op + }() + + return c.SendString(elements.StartProgressBar(uid, "0", "Deletion")) }) // Display the job current progress status @@ -118,12 +150,20 @@ func RegisterUIRoutes(app *fiber.App, // this route is hit when the job is done, and we display the // final state (for now just displays "Installation completed") app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error { + + status := galleryService.GetStatus(c.Params("uid")) + for _, k := range installingModels.Keys() { if installingModels.Get(k) == c.Params("uid") { installingModels.Delete(k) } } - return c.SendString(elements.DoneProgress(c.Params("uid"))) + displayText := "Installation completed" + if status.Deletion { + displayText = "Deletion completed" + } + + return c.SendString(elements.DoneProgress(c.Params("uid"), displayText)) }) } diff --git a/core/services/gallery.go b/core/services/gallery.go index b068abbb..6a54e38c 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "os" + "path/filepath" "strings" "sync" @@ -84,18 +85,47 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader } var err error - // if the request contains a gallery name, we apply the gallery from the gallery list - if op.GalleryName != "" { - if strings.Contains(op.GalleryName, "@") { - err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) - } else { - err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + + // delete a model + if op.Delete { + modelConfig := &config.BackendConfig{} + // Galleryname is the name of the model in this case + dat, err := os.ReadFile(filepath.Join(g.modelPath, op.GalleryName+".yaml")) + if err != nil { + updateError(err) + continue } - } else if op.ConfigURL != "" { - startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) - err = cl.Preload(g.modelPath) + err = yaml.Unmarshal(dat, modelConfig) + if err != nil { + updateError(err) + continue + } + + files := []string{} + // Remove the model from the config + if modelConfig.Model != "" { + files = append(files, modelConfig.ModelFileName()) + } + + if modelConfig.MMProj != "" { + files = append(files, modelConfig.MMProjFileName()) + } + + err = gallery.DeleteModelFromSystem(g.modelPath, op.GalleryName, files) } else { - err = prepareModel(g.modelPath, op.Req, cl, progressCallback) + // if the request contains a gallery name, we apply the gallery from the gallery list + if op.GalleryName != "" { + if strings.Contains(op.GalleryName, "@") { + err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + } + } else if op.ConfigURL != "" { + startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL) + err = cl.Preload(g.modelPath) + } else { + err = prepareModel(g.modelPath, op.Req, cl, progressCallback) + } } if err != nil { @@ -116,7 +146,12 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader continue } - g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100}) + g.UpdateStatus(op.Id, + &gallery.GalleryOpStatus{ + Deletion: op.Delete, + Processed: true, + Message: "completed", + Progress: 100}) } } }() diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index c4575817..d90ce4d9 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -1,6 +1,7 @@ package gallery import ( + "errors" "fmt" "os" "path/filepath" @@ -184,3 +185,48 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) } return models, nil } + +func DeleteModelFromSystem(basePath string, name string, additionalFiles []string) error { + // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. + name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + + configFile := filepath.Join(basePath, fmt.Sprintf("%s.yaml", name)) + + galleryFile := filepath.Join(basePath, galleryFileName(name)) + + var err error + // Delete all the files associated to the model + // read the model config + galleryconfig, err := ReadConfigFile(galleryFile) + if err != nil { + log.Error().Err(err).Msgf("failed to read gallery file %s", configFile) + } + + // Remove additional files + if galleryconfig != nil { + for _, f := range galleryconfig.Files { + fullPath := filepath.Join(basePath, f.Filename) + log.Debug().Msgf("Removing file %s", fullPath) + if e := os.Remove(fullPath); e != nil { + err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f.Filename, e)) + } + } + } + + for _, f := range additionalFiles { + fullPath := filepath.Join(filepath.Join(basePath, f)) + log.Debug().Msgf("Removing additional file %s", fullPath) + if e := os.Remove(fullPath); e != nil { + err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", f, e)) + } + } + + log.Debug().Msgf("Removing model config file %s", configFile) + + // Delete the model config file + if e := os.Remove(configFile); e != nil { + err = errors.Join(err, fmt.Errorf("failed to remove file %s: %w", configFile, e)) + } + + return err +} diff --git a/pkg/gallery/gallery_suite_test.go b/pkg/gallery/gallery_suite_test.go index 44256bc2..bf13cac9 100644 --- a/pkg/gallery/gallery_suite_test.go +++ b/pkg/gallery/gallery_suite_test.go @@ -1,6 +1,7 @@ package gallery_test import ( + "os" "testing" . "github.com/onsi/ginkgo/v2" @@ -11,3 +12,9 @@ func TestGallery(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Gallery test suite") } + +var _ = BeforeSuite(func() { + if os.Getenv("FIXTURES") == "" { + Fail("FIXTURES env var not set") + } +}) diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 2ab4c832..1fc6c0a2 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -178,5 +178,20 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides log.Debug().Msgf("Written config file %s", configFilePath) } - return nil + // Save the model gallery file for further reference + modelFile := filepath.Join(basePath, galleryFileName(name)) + data, err := yaml.Marshal(config) + if err != nil { + return err + } + + log.Debug().Msgf("Written gallery file %s", modelFile) + + return os.WriteFile(modelFile, data, 0600) + + //return nil +} + +func galleryFileName(name string) string { + return "._gallery_" + name + ".yaml" } diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index 6eb63128..bfc2b9a6 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -1,6 +1,7 @@ package gallery_test import ( + "errors" "os" "path/filepath" @@ -11,6 +12,7 @@ import ( ) var _ = Describe("Model test", func() { + Context("Downloading", func() { It("applies model correctly", func() { tempdir, err := os.MkdirTemp("", "test") @@ -80,6 +82,19 @@ var _ = Describe("Model test", func() { Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Installed).To(BeTrue()) + + // delete + err = DeleteModelFromSystem(tempdir, "bert", []string{}) + Expect(err).ToNot(HaveOccurred()) + + models, err = AvailableGalleryModels(galleries, tempdir) + Expect(err).ToNot(HaveOccurred()) + Expect(len(models)).To(Equal(1)) + Expect(models[0].Installed).To(BeFalse()) + + _, err = os.Stat(filepath.Join(tempdir, "bert.yaml")) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue()) }) It("renames model correctly", func() { diff --git a/pkg/gallery/op.go b/pkg/gallery/op.go index 73d748bf..4637820a 100644 --- a/pkg/gallery/op.go +++ b/pkg/gallery/op.go @@ -4,12 +4,14 @@ type GalleryOp struct { Id string GalleryName string ConfigURL string + Delete bool Req GalleryModel Galleries []Gallery } type GalleryOpStatus struct { + Deletion bool `json:"deletion"` // Deletion is true if the operation is a deletion FileName string `json:"file_name"` Error error `json:"error"` Processed bool `json:"processed"` diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 1b5c9aa0..2d6b3acb 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -96,7 +96,13 @@ func (ml *ModelLoader) ListModels() ([]string, error) { models := []string{} for _, file := range files { // Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method? - if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") || strings.HasSuffix(file.Name(), ".json") || strings.HasSuffix(file.Name(), ".DS_Store") { + if strings.HasSuffix(file.Name(), ".tmpl") || + strings.HasSuffix(file.Name(), ".keep") || + strings.HasSuffix(file.Name(), ".yaml") || + strings.HasSuffix(file.Name(), ".yml") || + strings.HasSuffix(file.Name(), ".json") || + strings.HasSuffix(file.Name(), ".DS_Store") || + strings.HasPrefix(file.Name(), ".") { continue }