feature: model list :: filter query string parameter (#830)

This commit is contained in:
Dave 2023-07-31 13:14:32 -04:00 committed by GitHub
parent 32ca7efbeb
commit ce8e9dc690
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 18 deletions

View File

@ -534,7 +534,7 @@ var _ = Describe("API test", func() {
It("returns the models list", func() { It("returns the models list", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(11)) Expect(len(models.Models)).To(Equal(6)) // If "config.yaml" should be included, this should be 8?
}) })
It("can generate completions", func() { It("can generate completions", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"}) resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
@ -738,19 +738,14 @@ var _ = Describe("API test", func() {
cancel() cancel()
app.Shutdown() app.Shutdown()
}) })
It("can generate chat completions from config file", func() { It("can generate chat completions from config file (list1)", func() {
models, err := client.ListModels(context.TODO()) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(13))
})
It("can generate chat completions from config file", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1)) Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
}) })
It("can generate chat completions from config file", func() { It("can generate chat completions from config file (list2)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1)) Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())

View File

@ -172,6 +172,16 @@ func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
return v, exists return v, exists
} }
func (cm *ConfigLoader) GetAllConfigs() []Config {
cm.Lock()
defer cm.Unlock()
var res []Config
for _, v := range cm.configs {
res = append(res, v)
}
return res
}
func (cm *ConfigLoader) ListConfigs() []string { func (cm *ConfigLoader) ListConfigs() []string {
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()

View File

@ -1,6 +1,8 @@
package openai package openai
import ( import (
"regexp"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -15,14 +17,43 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
var mm map[string]interface{} = map[string]interface{}{} var mm map[string]interface{} = map[string]interface{}{}
dataModels := []OpenAIModel{} dataModels := []OpenAIModel{}
for _, m := range models {
mm[m] = nil var filterFn func(name string) bool
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) filter := c.Query("filter")
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
} }
for _, k := range cm.ListConfigs() { // By default, exclude any loose files that are already referenced by a configuration file.
if _, exists := mm[k]; !exists { excludeConfigured := c.QueryBool("excludeConfigured", true)
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
// Start with the known configurations
for _, c := range cm.GetAllConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}
if filterFn(c.Name) {
dataModels = append(dataModels, OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
} }
} }

View File

@ -85,8 +85,8 @@ func (ml *ModelLoader) ListModels() ([]string, error) {
models := []string{} models := []string{}
for _, file := range files { for _, file := range files {
// Skip templates, YAML and .keep files // Skip templates, YAML, .keep, .json, and .DS_Store files - TODO: as this list grows, is there a more efficient method?
if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") { if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") || strings.HasSuffix(file.Name(), ".json") || strings.HasSuffix(file.Name(), ".DS_Store") {
continue continue
} }