From ce8e9dc690bd6c4091c06e26c1d5aabe34a0667f Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 31 Jul 2023 13:14:32 -0400 Subject: [PATCH] feature: model list :: filter query string parameter (#830) --- api/api_test.go | 15 +++++---------- api/config/config.go | 10 ++++++++++ api/openai/list.go | 43 +++++++++++++++++++++++++++++++++++++------ pkg/model/loader.go | 4 ++-- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/api/api_test.go b/api/api_test.go index c88002be..a58165c3 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -534,7 +534,7 @@ var _ = Describe("API test", func() { It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(11)) + Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8? }) It("can generate completions", func() { resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"}) @@ -738,19 +738,14 @@ var _ = Describe("API test", func() { cancel() app.Shutdown() }) - It("can generate chat completions from config file", func() { - models, err := client.ListModels(context.TODO()) - Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(13)) - }) - It("can generate chat completions from config file", func() { - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) + It("can generate chat completions from config file (list1)", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) - It("can generate chat completions from config file", func() { - resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) + It("can generate chat completions from config file (list2)", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) diff --git a/api/config/config.go b/api/config/config.go index 2d4ab2ce..75afaabf 100644 --- a/api/config/config.go +++ b/api/config/config.go @@ -172,6 +172,16 @@ func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { return v, exists } +func (cm *ConfigLoader) GetAllConfigs() []Config { + cm.Lock() + defer cm.Unlock() + var res []Config + for _, v := range cm.configs { + res = append(res, v) + } + return res +} + func (cm *ConfigLoader) ListConfigs() []string { cm.Lock() defer cm.Unlock() diff --git a/api/openai/list.go b/api/openai/list.go index 0cd7f3af..59159921 100644 --- a/api/openai/list.go +++ b/api/openai/list.go @@ -1,6 +1,8 @@ package openai import ( + "regexp" + config "github.com/go-skynet/LocalAI/api/config" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" @@ -15,14 +17,43 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func var mm map[string]interface{} = map[string]interface{}{} dataModels := []OpenAIModel{} - for _, m := range models { - mm[m] = nil - dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) + + var filterFn func(name string) bool + filter := c.Query("filter") + + // If filter is not specified, do not filter the list by model name + if filter == "" { + filterFn = func(_ string) bool { return true } + } else { + // If filter _IS_ specified, we compile it to a regex which is used to create the filterFn + rxp, err := regexp.Compile(filter) + if err != nil { + return err + } + filterFn = func(name string) bool { + return rxp.MatchString(name) + } } - for _, k := range cm.ListConfigs() { - if _, exists := mm[k]; !exists { - dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) + // By default, exclude any loose files that are already referenced by a configuration file. + excludeConfigured := c.QueryBool("excludeConfigured", true) + + // Start with the known configurations + for _, c := range cm.GetAllConfigs() { + if excludeConfigured { + mm[c.Model] = nil + } + + if filterFn(c.Name) { + dataModels = append(dataModels, OpenAIModel{ID: c.Name, Object: "model"}) + } + } + + // Then iterate through the loose files: + for _, m := range models { + // And only adds them if they shouldn't be skipped. + if _, exists := mm[m]; !exists && filterFn(m) { + dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) } } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index daadc969..73d13ebd 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -85,8 +85,8 @@ func (ml *ModelLoader) ListModels() ([]string, error) { models := []string{} for _, file := range files { - // Skip templates, YAML and .keep files - if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") { + // Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method? + if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") || strings.HasSuffix(file.Name(), ".json") || strings.HasSuffix(file.Name(), ".DS_Store") { continue }