From e73283121bece33d40d080edaa7bd9d3c88d7486 Mon Sep 17 00:00:00 2001 From: mudler Date: Fri, 5 May 2023 15:54:59 +0200 Subject: [PATCH] feat: support arrays for prompt and input Signed-off-by: mudler --- api/config.go | 2 + api/openai.go | 101 ++++++++++++++++++++++++++++------------------ api/prediction.go | 1 + 3 files changed, 64 insertions(+), 40 deletions(-) diff --git a/api/config.go b/api/config.go index 8e550e17..d5df3de1 100644 --- a/api/config.go +++ b/api/config.go @@ -27,6 +27,8 @@ type Config struct { MirostatETA float64 `yaml:"mirostat_eta"` MirostatTAU float64 `yaml:"mirostat_tau"` Mirostat int `yaml:"mirostat"` + + PromptStrings, InputStrings []string } type TemplateConfig struct { diff --git a/api/openai.go b/api/openai.go index fc982f25..3a6b947d 100644 --- a/api/openai.go +++ b/api/openai.go @@ -75,8 +75,8 @@ type OpenAIRequest struct { Prompt interface{} `json:"prompt" yaml:"prompt"` // Edit endpoint - Instruction string `json:"instruction" yaml:"instruction"` - Input string `json:"input" yaml:"input"` + Instruction string `json:"instruction" yaml:"instruction"` + Input interface{} `json:"input" yaml:"input"` Stop interface{} `json:"stop" yaml:"stop"` @@ -184,6 +184,30 @@ func updateConfig(config *Config, input *OpenAIRequest) { if input.MirostatTAU != 0 { config.MirostatTAU = input.MirostatTAU } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + if s, ok := pp.(string); ok { + config.InputStrings = append(config.InputStrings, s) + } + } + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } } func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { @@ -268,19 +292,6 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, log.Debug().Msgf("Parameter Config: %+v", config) - predInput := []string{} - - switch p := input.Prompt.(type) { - case string: - predInput = append(predInput, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - predInput = append(predInput, s) - } - } - } - templateFile := config.Model if config.TemplateConfig.Completion != "" { @@ -288,7 +299,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } var result []Choice - for _, i := range predInput { + for _, i := range config.PromptStrings { // A model can have a "file.bin.tmpl" file associated with a prompt template prefix templatedInput, err := loader.TemplatePrefix(templateFile, struct { Input string @@ -331,20 +342,26 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } log.Debug().Msgf("Parameter Config: %+v", config) + items := []Item{} - // get the model function to call for the result - embedFn, err := ModelEmbedding(input.Input, loader, *config) - if err != nil { - return err + for i, s := range config.InputStrings { + + // get the model function to call for the result + embedFn, err := ModelEmbedding(s, loader, *config) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) } - embeddings, err := embedFn() - if err != nil { - return err - } resp := &OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: []Item{{Embedding: embeddings, Index: 0, Object: "embedding"}}, + Data: items, Object: "list", } @@ -480,28 +497,32 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread log.Debug().Msgf("Parameter Config: %+v", config) - predInput := input.Input templateFile := config.Model if config.TemplateConfig.Edit != "" { templateFile = config.TemplateConfig.Edit } - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(templateFile, struct { - Input string - Instruction string - }{Input: predInput, Instruction: input.Instruction}) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } + var result []Choice + for _, i := range config.InputStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := loader.TemplatePrefix(templateFile, struct { + Input string + Instruction string + }{Input: i}) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } - result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err + r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) } resp := &OpenAIResponse{ diff --git a/api/prediction.go b/api/prediction.go index 45db078a..009641a2 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -28,6 +28,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption { if c.Embeddings { llamaOpts = append(llamaOpts, llama.EnableEmbeddings) } + return llamaOpts }