From 6ac5d814fbb5faa26ed041ab5f7864441f431eef Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 28 Jan 2024 00:14:16 +0100 Subject: [PATCH] feat(startup): fetch model definition remotely (#1654) --- api/api.go | 2 +- api/options/options.go | 8 ++++++++ embedded/embedded.go | 15 +++++++++++++++ main.go | 11 +++++++++++ pkg/startup/model_preload.go | 16 ++++++++++++++-- pkg/startup/model_preload_test.go | 22 +++++++++++++++++++--- 6 files changed, 68 insertions(+), 6 deletions(-) diff --git a/api/api.go b/api/api.go index 82e0f69b..7ec95f1b 100644 --- a/api/api.go +++ b/api/api.go @@ -37,7 +37,7 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - startup.PreloadModelsConfigurations(options.Loader.ModelPath, options.ModelsURL...) + startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...) cl := config.NewConfigLoader() if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { diff --git a/api/options/options.go b/api/options/options.go index e83eaaad..8c066584 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -28,6 +28,8 @@ type Option struct { ApiKeys []string Metrics *metrics.Metrics + ModelLibraryURL string + Galleries []gallery.Gallery BackendAssets embed.FS @@ -78,6 +80,12 @@ func WithCors(b bool) AppOption { } } +func WithModelLibraryURL(url string) AppOption { + return func(o *Option) { + o.ModelLibraryURL = url + } +} + var EnableWatchDog = func(o *Option) { o.WatchDog = true } diff --git a/embedded/embedded.go b/embedded/embedded.go index a76e87cd..c779fc26 100644 --- a/embedded/embedded.go +++ b/embedded/embedded.go @@ -6,6 +6,8 @@ import ( "slices" "strings" + "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/assets" "gopkg.in/yaml.v3" ) @@ -30,6 +32,19 @@ func init() { yaml.Unmarshal(modelLibrary, &modelShorteners) } +func GetRemoteLibraryShorteners(url string) (map[string]string, error) { + remoteLibrary := map[string]string{} + + err := downloader.GetURI(url, func(_ string, i []byte) error { + return yaml.Unmarshal(i, &remoteLibrary) + }) + if err != nil { + return nil, fmt.Errorf("error downloading remote library: %s", err.Error()) + } + + return remoteLibrary, err +} + // ExistsInModelsLibrary checks if a model exists in the embedded models library func ExistsInModelsLibrary(s string) bool { f := fmt.Sprintf("%s.yaml", s) diff --git a/main.go b/main.go index 39e38686..d2209285 100644 --- a/main.go +++ b/main.go @@ -26,6 +26,10 @@ import ( "github.com/urfave/cli/v2" ) +const ( + remoteLibraryURL = "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" +) + func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) // clean up process @@ -94,6 +98,12 @@ func main() { Usage: "JSON list of galleries", EnvVars: []string{"GALLERIES"}, }, + &cli.StringFlag{ + Name: "remote-library", + Usage: "A LocalAI remote library URL", + EnvVars: []string{"REMOTE_LIBRARY"}, + Value: remoteLibraryURL, + }, &cli.StringFlag{ Name: "preload-models", Usage: "A List of models to apply in JSON at start", @@ -219,6 +229,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithAudioDir(ctx.String("audio-path")), options.WithF16(ctx.Bool("f16")), options.WithStringGalleries(ctx.String("galleries")), + options.WithModelLibraryURL(ctx.String("remote-library")), options.WithDisableMessage(false), options.WithCors(ctx.Bool("cors")), options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index c23b7b41..cc514334 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -14,10 +14,22 @@ import ( // PreloadModelsConfigurations will preload models from the given list of URLs // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration -func PreloadModelsConfigurations(modelPath string, models ...string) { +func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { for _, url := range models { - url = embedded.ModelShortURL(url) + // As a best effort, try to resolve the model from the remote library + // if it's not resolved we try with the other method below + if modelLibraryURL != "" { + lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL) + if err == nil { + if lib[url] != "" { + log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) + url = lib[url] + } + } + } + + url = embedded.ModelShortURL(url) switch { case embedded.ExistsInModelsLibrary(url): modelYAML, err := embedded.ResolveContent(url) diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index d1e0eab3..63a8f8b0 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -15,13 +15,29 @@ import ( var _ = Describe("Preload test", func() { Context("Preloading from strings", func() { + It("loads from remote url", func() { + tmpdir, err := os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" + fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") + + PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2") + + resultFile := filepath.Join(tmpdir, fileName) + + content, err := os.ReadFile(resultFile) + Expect(err).ToNot(HaveOccurred()) + + Expect(string(content)).To(ContainSubstring("name: phi-2")) + }) + It("loads from embedded full-urls", func() { tmpdir, err := os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - PreloadModelsConfigurations(tmpdir, url) + PreloadModelsConfigurations("", tmpdir, url) resultFile := filepath.Join(tmpdir, fileName) @@ -35,7 +51,7 @@ var _ = Describe("Preload test", func() { Expect(err).ToNot(HaveOccurred()) url := "phi-2" - PreloadModelsConfigurations(tmpdir, url) + PreloadModelsConfigurations("", tmpdir, url) entry, err := os.ReadDir(tmpdir) Expect(err).ToNot(HaveOccurred()) @@ -53,7 +69,7 @@ var _ = Describe("Preload test", func() { url := "mistral-openorca" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - PreloadModelsConfigurations(tmpdir, url) + PreloadModelsConfigurations("", tmpdir, url) resultFile := filepath.Join(tmpdir, fileName)