From 0d8bf91699a9deee596011cb1c30be29ec680685 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 23 Apr 2024 09:22:58 +0200 Subject: [PATCH] feat: Galleries UI (#2104) * WIP: add models to webui Signed-off-by: Ettore Di Giacinto * Register routes Signed-off-by: Ettore Di Giacinto * fix: don't cache models Signed-off-by: Ettore Di Giacinto * small fixups Signed-off-by: Ettore Di Giacinto * fix: fixup multiple installs (strings.Clone) Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- README.md | 2 +- core/config/backend_config.go | 6 +- core/http/app.go | 6 +- core/http/elements/gallery.go | 171 +++++++++++++++++++++++++ core/http/endpoints/localai/welcome.go | 6 +- core/http/routes/localai.go | 3 +- core/http/routes/ui.go | 107 ++++++++++++++++ core/http/routes/welcome.go | 6 +- core/http/views/models.html | 40 ++++++ core/http/views/partials/head.html | 67 +++++++++- core/http/views/partials/navbar.html | 1 + docs/content/docs/overview.md | 2 +- go.mod | 5 +- go.sum | 2 + main.go | 2 +- pkg/downloader/progress.go | 13 ++ pkg/downloader/uri.go | 4 +- pkg/gallery/models.go | 4 +- pkg/gallery/op.go | 5 +- pkg/startup/model_preload.go | 2 +- 20 files changed, 431 insertions(+), 23 deletions(-) create mode 100644 core/http/elements/gallery.go create mode 100644 core/http/routes/ui.go create mode 100644 core/http/views/models.html diff --git a/README.md b/README.md index e28e3cb0..0b32febd 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ [![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai) -**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that’s compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. +**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that’s compatible with OpenAI (Elevenlabs, Anthropic... ) API specifications for local AI inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler). ## 🔥🔥 Hot topics / Roadmap diff --git a/core/config/backend_config.go b/core/config/backend_config.go index dfc216dc..64182e75 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -512,7 +512,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { for i, config := range cl.configs { // Download files and verify their SHA - for _, file := range config.DownloadFiles { + for i, file := range config.DownloadFiles { log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) if err := utils.VerifyPath(file.Filename, modelPath); err != nil { @@ -521,7 +521,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { // Create file path filePath := filepath.Join(modelPath, file.Filename) - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil { + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil { return err } } @@ -535,7 +535,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error { // check if file exists if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { - err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status) + err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", 0, 0, status) if err != nil { return err } diff --git a/core/http/app.go b/core/http/app.go index 1061627f..21652dd9 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -186,10 +186,14 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) + galleryService := services.NewGalleryService(appConfig.ModelPath) + galleryService.Start(appConfig.Context, cl) + routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth) - routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, auth) + routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService, auth) routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth) routes.RegisterPagesRoutes(app, cl, ml, appConfig, auth) + routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth) // Define a custom 404 handler // Note: keep this at the bottom! diff --git a/core/http/elements/gallery.go b/core/http/elements/gallery.go new file mode 100644 index 00000000..370ca82d --- /dev/null +++ b/core/http/elements/gallery.go @@ -0,0 +1,171 @@ +package elements + +import ( + "fmt" + + "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" + "github.com/go-skynet/LocalAI/pkg/gallery" +) + +func DoneProgress(uid string) string { + return elem.Div( + attrs.Props{}, + elem.H3( + attrs.Props{ + "role": "status", + "id": "pblabel", + "tabindex": "-1", + "autofocus": "", + }, + elem.Text("Installation completed"), + ), + ).Render() +} + +func ErrorProgress(err string) string { + return elem.Div( + attrs.Props{}, + elem.H3( + attrs.Props{ + "role": "status", + "id": "pblabel", + "tabindex": "-1", + "autofocus": "", + }, + elem.Text("Error"+err), + ), + ).Render() +} + +func ProgressBar(progress string) string { + return elem.Div(attrs.Props{ + "class": "progress", + "role": "progressbar", + "aria-valuemin": "0", + "aria-valuemax": "100", + "aria-valuenow": "0", + "aria-labelledby": "pblabel", + }, + elem.Div(attrs.Props{ + "id": "pb", + "class": "progress-bar", + "style": "width:" + progress + "%", + }), + ).Render() +} + +func StartProgressBar(uid, progress string) string { + if progress == "" { + progress = "0" + } + return elem.Div(attrs.Props{ + "hx-trigger": "done", + "hx-get": "/browse/job/" + uid, + "hx-swap": "outerHTML", + "hx-target": "this", + }, + elem.H3( + attrs.Props{ + "role": "status", + "id": "pblabel", + "tabindex": "-1", + "autofocus": "", + }, + elem.Text("Installing"), + // This is a simple example of how to use the HTMLX library to create a progress bar that updates every 600ms. + elem.Div(attrs.Props{ + "hx-get": "/browse/job/progress/" + uid, + "hx-trigger": "every 600ms", + "hx-target": "this", + "hx-swap": "innerHTML", + }, + elem.Raw(ProgressBar(progress)), + ), + ), + ).Render() +} + +func ListModels(models []*gallery.GalleryModel) string { + modelsElements := []elem.Node{} + span := func(s string) elem.Node { + return elem.Span( + attrs.Props{ + "class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs", + }, + elem.Text(s), + ) + } + installButton := func(m *gallery.GalleryModel) elem.Node { + return elem.Button( + attrs.Props{ + "class": "float-right inline-block rounded bg-primary px-6 pb-2 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong", + // post the Model ID as param + "hx-post": "/browse/install/model/" + fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name), + }, + elem.Text("Install"), + ) + } + + descriptionDiv := func(m *gallery.GalleryModel) elem.Node { + + return elem.Div( + attrs.Props{ + "class": "p-6", + }, + elem.H5( + attrs.Props{ + "class": "mb-2 text-xl font-medium leading-tight", + }, + elem.Text(m.Name), + ), + elem.P( + attrs.Props{ + "class": "mb-4 text-base", + }, + elem.Text(m.Description), + ), + ) + } + + actionDiv := func(m *gallery.GalleryModel) elem.Node { + return elem.Div( + attrs.Props{ + "class": "px-6 pt-4 pb-2", + }, + elem.Span( + attrs.Props{ + "class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2", + }, + elem.Text("Repository: "+m.Gallery.Name), + ), + elem.If(m.Installed, span("Installed"), installButton(m)), + ) + } + + for _, m := range models { + modelsElements = append(modelsElements, + elem.Div( + attrs.Props{ + "class": "me-4 mb-2 block rounded-lg bg-white shadow-secondary-1 dark:bg-gray-800 dark:bg-surface-dark dark:text-white text-surface p-2", + }, + elem.Div( + attrs.Props{ + "class": "p-6", + }, + descriptionDiv(m), + actionDiv(m), + // elem.If(m.Installed, span("Installed"), installButton(m)), + + // elem.If(m.Installed, span("Installed"), span("Not Installed")), + ), + ), + ) + } + + wrapper := elem.Div(attrs.Props{ + "class": "dark grid grid-cols-1 grid-rows-1 md:grid-cols-2 ", + }, modelsElements...) + + return wrapper.Render() +} diff --git a/core/http/endpoints/localai/welcome.go b/core/http/endpoints/localai/welcome.go index fd3e6230..291422c6 100644 --- a/core/http/endpoints/localai/welcome.go +++ b/core/http/endpoints/localai/welcome.go @@ -3,12 +3,16 @@ package localai import ( "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/internal" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) func WelcomeEndpoint(appConfig *config.ApplicationConfig, - models []string, backendConfigs []config.BackendConfig) func(*fiber.Ctx) error { + cl *config.BackendConfigLoader, ml *model.ModelLoader) func(*fiber.Ctx) error { return func(c *fiber.Ctx) error { + models, _ := ml.ListModels() + backendConfigs := cl.GetAllBackendConfigs() + summary := fiber.Map{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 2651a53e..6415c894 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -14,13 +14,12 @@ func RegisterLocalAIRoutes(app *fiber.App, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, + galleryService *services.GalleryService, auth func(*fiber.Ctx) error) { app.Get("/swagger/*", swagger.HandlerDefault) // default // LocalAI API endpoints - galleryService := services.NewGalleryService(appConfig.ModelPath) - galleryService.Start(appConfig.Context, cl) modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go new file mode 100644 index 00000000..b9ccd89a --- /dev/null +++ b/core/http/routes/ui.go @@ -0,0 +1,107 @@ +package routes + +import ( + "fmt" + "html/template" + "strings" + + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/http/elements" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" +) + +func RegisterUIRoutes(app *fiber.App, + cl *config.BackendConfigLoader, + ml *model.ModelLoader, + appConfig *config.ApplicationConfig, + galleryService *services.GalleryService, + auth func(*fiber.Ctx) error) { + + // Show the Models page + app.Get("/browse", auth, func(c *fiber.Ctx) error { + models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath) + + summary := fiber.Map{ + "Title": "LocalAI API - Models", + "Models": template.HTML(elements.ListModels(models)), + // "ApplicationConfig": appConfig, + } + + // Render index + return c.Render("views/models", summary) + }) + + // HTMX: return the model details + // https://htmx.org/examples/active-search/ + app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error { + form := struct { + Search string `form:"search"` + }{} + if err := c.BodyParser(&form); err != nil { + return c.Status(fiber.StatusBadRequest).SendString(err.Error()) + } + + models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath) + + filteredModels := []*gallery.GalleryModel{} + for _, m := range models { + if strings.Contains(m.Name, form.Search) { + filteredModels = append(filteredModels, m) + } + } + + return c.SendString(elements.ListModels(filteredModels)) + }) + + // https://htmx.org/examples/progress-bar/ + app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error { + galleryID := strings.Clone(c.Params("id")) // strings.Clone is required! + + id, err := uuid.NewUUID() + if err != nil { + return err + } + + uid := id.String() + + op := gallery.GalleryOp{ + Id: uid, + GalleryName: galleryID, + Galleries: appConfig.Galleries, + } + go func() { + galleryService.C <- op + }() + + return c.SendString(elements.StartProgressBar(uid, "0")) + }) + + // https://htmx.org/examples/progress-bar/ + app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error { + jobUID := c.Params("uid") + + status := galleryService.GetStatus(jobUID) + if status == nil { + //fmt.Errorf("could not find any status for ID") + return c.SendString(elements.ProgressBar("0")) + } + + if status.Progress == 100 { + c.Set("HX-Trigger", "done") + return c.SendString(elements.ProgressBar("100")) + } + if status.Error != nil { + return c.SendString(elements.ErrorProgress(status.Error.Error())) + } + + return c.SendString(elements.ProgressBar(fmt.Sprint(status.Progress))) + }) + + app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error { + return c.SendString(elements.DoneProgress(c.Params("uid"))) + }) +} diff --git a/core/http/routes/welcome.go b/core/http/routes/welcome.go index 29b9e586..6b600d2d 100644 --- a/core/http/routes/welcome.go +++ b/core/http/routes/welcome.go @@ -13,11 +13,7 @@ func RegisterPagesRoutes(app *fiber.App, appConfig *config.ApplicationConfig, auth func(*fiber.Ctx) error) { - models, _ := ml.ListModels() - backendConfigs := cl.GetAllBackendConfigs() - if !appConfig.DisableWelcomePage { - app.Get("/", auth, localai.WelcomeEndpoint(appConfig, models, backendConfigs)) + app.Get("/", auth, localai.WelcomeEndpoint(appConfig, cl, ml)) } - } diff --git a/core/http/views/models.html b/core/http/views/models.html new file mode 100644 index 00000000..63c6bba0 --- /dev/null +++ b/core/http/views/models.html @@ -0,0 +1,40 @@ + + +{{template "views/partials/head" .}} + + +
+ + {{template "views/partials/navbar" .}} +
+
+

Welcome to your LocalAI instance!

+
+ +
+

The FOSS alternative to OpenAI, Claude, ...

+ + Documentation + +
+ +
+

Available models from repositories

+ + + + +
{{.Models}}
+
+
+ + {{template "views/partials/footer" .}} +
+ + + diff --git a/core/http/views/partials/head.html b/core/http/views/partials/head.html index 59cdea33..9dbfecdb 100644 --- a/core/http/views/partials/head.html +++ b/core/http/views/partials/head.html @@ -3,11 +3,76 @@ {{.Title}} - + + + + + \ No newline at end of file diff --git a/core/http/views/partials/navbar.html b/core/http/views/partials/navbar.html index c3d3223f..36332ed2 100644 --- a/core/http/views/partials/navbar.html +++ b/core/http/views/partials/navbar.html @@ -9,6 +9,7 @@ diff --git a/docs/content/docs/overview.md b/docs/content/docs/overview.md index 5224bc49..f0f59494 100644 --- a/docs/content/docs/overview.md +++ b/docs/content/docs/overview.md @@ -56,7 +56,7 @@ icon = "info" -**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that's compatible with OpenAI API specifications for local inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families and architectures. Does not require GPU. It is maintained by [mudler](https://github.com/mudler). +**LocalAI** is the free, Open Source OpenAI alternative. LocalAI act as a drop-in replacement REST API that's compatible with OpenAI API specifications for local inferencing. It allows you to run LLMs, generate images, audio (and not only) locally or on-prem with consumer grade hardware, supporting multiple model families and architectures. Does not require GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler). ## Start LocalAI diff --git a/go.mod b/go.mod index 0bf9aa02..9485383e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/go-skynet/LocalAI -go 1.21 +go 1.21.1 + +toolchain go1.22.2 require ( github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf @@ -71,6 +73,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chasefleming/elem-go v0.25.0 // indirect github.com/containerd/continuity v0.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.8.1 // indirect diff --git a/go.sum b/go.sum index 55fdaf06..b68834b2 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng= github.com/charmbracelet/glamour v0.7.0/go.mod h1:jUMh5MeihljJPQbJ/wf4ldw2+yBP59+ctV36jASy7ps= +github.com/chasefleming/elem-go v0.25.0 h1:LYzr1auk39Bh3bdKloArOFV7sOBnOfSOKxsg58eWL0Q= +github.com/chasefleming/elem-go v0.25.0/go.mod h1:hz73qILBIKnTgOujnSMtEj20/epI+f6vg71RUilJAA4= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= diff --git a/main.go b/main.go index 9976906b..04f13d3f 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,7 @@ Version: ${version} kong.Vars{ "basepath": kong.ExpandPath("."), "remoteLibraryURL": "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml", - "galleries": `[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml"}]`, + "galleries": `[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml@master"}]`, "version": internal.PrintableVersion(), }, ) diff --git a/pkg/downloader/progress.go b/pkg/downloader/progress.go index 6806f586..6cd6132b 100644 --- a/pkg/downloader/progress.go +++ b/pkg/downloader/progress.go @@ -5,6 +5,8 @@ import "hash" type progressWriter struct { fileName string total int64 + fileNo int + totalFiles int written int64 downloadStatus func(string, string, string, float64) hash hash.Hash @@ -16,6 +18,17 @@ func (pw *progressWriter) Write(p []byte) (n int, err error) { if pw.total > 0 { percentage := float64(pw.written) / float64(pw.total) * 100 + if pw.totalFiles > 1 { + // This is a multi-file download + // so we need to adjust the percentage + // to reflect the progress of the whole download + // This is the file pw.fileNo of pw.totalFiles files. We assume that + // the files before successfully downloaded. + percentage = percentage / float64(pw.totalFiles) + if pw.fileNo > 1 { + percentage += float64(pw.fileNo-1) * 100 / float64(pw.totalFiles) + } + } //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) } else { diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index b678ae0d..46ccd6a1 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -136,7 +136,7 @@ func removePartialFile(tmpFilePath string) error { return nil } -func DownloadFile(url string, filePath, sha string, downloadStatus func(string, string, string, float64)) error { +func DownloadFile(url string, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { url = ConvertURL(url) // Check if the file already exists _, err := os.Stat(filePath) @@ -209,6 +209,8 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string, fileName: tmpFilePath, total: resp.ContentLength, hash: sha256.New(), + fileNo: fileN, + totalFiles: total, downloadStatus: downloadStatus, } _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 10caedee..59971bbc 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -102,7 +102,7 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides } // Download files and verify their SHA - for _, file := range config.Files { + for i, file := range config.Files { log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) if err := utils.VerifyPath(file.Filename, basePath); err != nil { @@ -111,7 +111,7 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides // Create file path filePath := filepath.Join(basePath, file.Filename) - if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil { + if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { return err } } diff --git a/pkg/gallery/op.go b/pkg/gallery/op.go index 99796812..73d748bf 100644 --- a/pkg/gallery/op.go +++ b/pkg/gallery/op.go @@ -1,11 +1,12 @@ package gallery type GalleryOp struct { - Req GalleryModel Id string - Galleries []Gallery GalleryName string ConfigURL string + + Req GalleryModel + Galleries []Gallery } type GalleryOpStatus struct { diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index b09516a7..d267d846 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -54,7 +54,7 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model // check if file exists if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - err := downloader.DownloadFile(url, modelDefinitionFilePath, "", func(fileName, current, total string, percent float64) { + err := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) if err != nil {