diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 64182e75..35e0776d 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -238,7 +238,13 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) { if cfg.MMap == nil { // MMap is enabled by default - cfg.MMap = &trueV + + // Only exception is for Intel GPUs + if os.Getenv("XPU") != "" { + cfg.MMap = &falseV + } else { + cfg.MMap = &trueV + } } if cfg.MMlock == nil { diff --git a/core/http/elements/gallery.go b/core/http/elements/gallery.go index c03750da..6edbd23d 100644 --- a/core/http/elements/gallery.go +++ b/core/http/elements/gallery.go @@ -6,6 +6,7 @@ import ( "github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go/attrs" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/xsync" ) const ( @@ -102,7 +103,8 @@ func cardSpan(text, icon string) elem.Node { ) } -func ListModels(models []*gallery.GalleryModel) string { +func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[string, string]) string { + //StartProgressBar(uid, "0") modelsElements := []elem.Node{} span := func(s string) elem.Node { return elem.Span( @@ -118,6 +120,7 @@ func ListModels(models []*gallery.GalleryModel) string { "data-twe-ripple-init": "", "data-twe-ripple-color": "light", "class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong", + "hx-swap": "outerHTML", // post the Model ID as param "hx-post": "/browse/install/model/" + fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name), }, @@ -152,6 +155,9 @@ func ListModels(models []*gallery.GalleryModel) string { } actionDiv := func(m *gallery.GalleryModel) elem.Node { + galleryID := fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name) + currentlyInstalling := installing.Exists(galleryID) + nodes := []elem.Node{ cardSpan("Repository: "+m.Gallery.Name, "fa-brands fa-git-alt"), } @@ -193,7 +199,16 @@ func ListModels(models []*gallery.GalleryModel) string { }, nodes..., ), - elem.If(m.Installed, span("Installed"), installButton(m)), + elem.If( + currentlyInstalling, + elem.Node( // If currently installing, show progress bar + elem.Raw(StartProgressBar(installing.Get(galleryID), "0")), + ), // Otherwise, show install button (if not installed) or display "Installed" + elem.If(m.Installed, + span("Installed"), + installButton(m), + ), + ), ) } diff --git a/core/http/routes/ui.go b/core/http/routes/ui.go index c64ec5ff..b63b1870 100644 --- a/core/http/routes/ui.go +++ b/core/http/routes/ui.go @@ -10,6 +10,8 @@ import ( "github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/xsync" + "github.com/gofiber/fiber/v2" "github.com/google/uuid" ) @@ -21,13 +23,16 @@ func RegisterUIRoutes(app *fiber.App, galleryService *services.GalleryService, auth func(*fiber.Ctx) error) { - // Show the Models page + // keeps the state of models that are being installed from the UI + var installingModels = xsync.NewSyncedMap[string, string]() + + // Show the Models page (all models) app.Get("/browse", auth, func(c *fiber.Ctx) error { models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath) summary := fiber.Map{ "Title": "LocalAI - Models", - "Models": template.HTML(elements.ListModels(models)), + "Models": template.HTML(elements.ListModels(models, installingModels)), "Repositories": appConfig.Galleries, // "ApplicationConfig": appConfig, } @@ -36,7 +41,7 @@ func RegisterUIRoutes(app *fiber.App, return c.Render("views/models", summary) }) - // HTMX: return the model details + // Show the models, filtered from the user input // https://htmx.org/examples/active-search/ app.Post("/browse/search/models", auth, func(c *fiber.Ctx) error { form := struct { @@ -58,12 +63,13 @@ func RegisterUIRoutes(app *fiber.App, } } - return c.SendString(elements.ListModels(filteredModels)) + return c.SendString(elements.ListModels(filteredModels, installingModels)) }) + // This route is used when the "Install" button is pressed, we submit here a new job to the gallery service // https://htmx.org/examples/progress-bar/ app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error { - galleryID := strings.Clone(c.Params("id")) // strings.Clone is required! + galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests! id, err := uuid.NewUUID() if err != nil { @@ -72,6 +78,8 @@ func RegisterUIRoutes(app *fiber.App, uid := id.String() + installingModels.Set(galleryID, uid) + op := gallery.GalleryOp{ Id: uid, GalleryName: galleryID, @@ -84,6 +92,8 @@ func RegisterUIRoutes(app *fiber.App, return c.SendString(elements.StartProgressBar(uid, "0")) }) + // Display the job current progress status + // If the job is done, we trigger the /browse/job/:uid route // https://htmx.org/examples/progress-bar/ app.Get("/browse/job/progress/:uid", auth, func(c *fiber.Ctx) error { jobUID := c.Params("uid") @@ -95,7 +105,7 @@ func RegisterUIRoutes(app *fiber.App, } if status.Progress == 100 { - c.Set("HX-Trigger", "done") + c.Set("HX-Trigger", "done") // this triggers /browse/job/:uid (which is when the job is done) return c.SendString(elements.ProgressBar("100")) } if status.Error != nil { @@ -105,7 +115,15 @@ func RegisterUIRoutes(app *fiber.App, return c.SendString(elements.ProgressBar(fmt.Sprint(status.Progress))) }) + // this route is hit when the job is done, and we display the + // final state (for now just displays "Installation completed") app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error { + for _, k := range installingModels.Keys() { + if installingModels.Get(k) == c.Params("uid") { + installingModels.Delete(k) + } + } + return c.SendString(elements.DoneProgress(c.Params("uid"))) }) } diff --git a/pkg/xsync/map.go b/pkg/xsync/map.go new file mode 100644 index 00000000..9c3a471e --- /dev/null +++ b/pkg/xsync/map.go @@ -0,0 +1,77 @@ +package xsync + +import ( + "sync" +) + +type SyncedMap[K comparable, V any] struct { + mu sync.RWMutex + m map[K]V +} + +func NewSyncedMap[K comparable, V any]() *SyncedMap[K, V] { + return &SyncedMap[K, V]{ + m: make(map[K]V), + } +} + +func (m *SyncedMap[K, V]) Get(key K) V { + m.mu.RLock() + defer m.mu.RUnlock() + return m.m[key] +} + +func (m *SyncedMap[K, V]) Keys() []K { + m.mu.RLock() + defer m.mu.RUnlock() + keys := make([]K, 0, len(m.m)) + for k := range m.m { + keys = append(keys, k) + } + return keys +} + +func (m *SyncedMap[K, V]) Values() []V { + m.mu.RLock() + defer m.mu.RUnlock() + values := make([]V, 0, len(m.m)) + for _, v := range m.m { + values = append(values, v) + } + return values +} + +func (m *SyncedMap[K, V]) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.m) +} + +func (m *SyncedMap[K, V]) Iterate(f func(key K, value V) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for k, v := range m.m { + if !f(k, v) { + break + } + } +} + +func (m *SyncedMap[K, V]) Set(key K, value V) { + m.mu.Lock() + m.m[key] = value + m.mu.Unlock() +} + +func (m *SyncedMap[K, V]) Delete(key K) { + m.mu.Lock() + delete(m.m, key) + m.mu.Unlock() +} + +func (m *SyncedMap[K, V]) Exists(key K) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, ok := m.m[key] + return ok +} diff --git a/pkg/xsync/map_test.go b/pkg/xsync/map_test.go new file mode 100644 index 00000000..a7ecfbcc --- /dev/null +++ b/pkg/xsync/map_test.go @@ -0,0 +1,26 @@ +package xsync_test + +import ( + . "github.com/go-skynet/LocalAI/pkg/xsync" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("SyncMap", func() { + + Context("Syncmap", func() { + It("sets and gets", func() { + m := NewSyncedMap[string, string]() + m.Set("foo", "bar") + Expect(m.Get("foo")).To(Equal("bar")) + }) + It("deletes", func() { + m := NewSyncedMap[string, string]() + m.Set("foo", "bar") + m.Delete("foo") + Expect(m.Get("foo")).To(Equal("")) + Expect(m.Exists("foo")).To(Equal(false)) + }) + }) +}) diff --git a/pkg/xsync/sync_suite_test.go b/pkg/xsync/sync_suite_test.go new file mode 100644 index 00000000..0dad9c66 --- /dev/null +++ b/pkg/xsync/sync_suite_test.go @@ -0,0 +1,13 @@ +package xsync_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestSync(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "LocalAI sync test") +}