feat: support arrays for prompt and input

Signed-off-by: mudler <mudler@mocaccino.org>
This commit is contained in:
mudler 2023-05-05 15:54:59 +02:00
parent 857d13e8d6
commit e73283121b
3 changed files with 64 additions and 40 deletions

View File

@ -27,6 +27,8 @@ type Config struct {
MirostatETA float64 `yaml:"mirostat_eta"` MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"` MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"` Mirostat int `yaml:"mirostat"`
PromptStrings, InputStrings []string
} }
type TemplateConfig struct { type TemplateConfig struct {

View File

@ -76,7 +76,7 @@ type OpenAIRequest struct {
// Edit endpoint // Edit endpoint
Instruction string `json:"instruction" yaml:"instruction"` Instruction string `json:"instruction" yaml:"instruction"`
Input string `json:"input" yaml:"input"` Input interface{} `json:"input" yaml:"input"`
Stop interface{} `json:"stop" yaml:"stop"` Stop interface{} `json:"stop" yaml:"stop"`
@ -184,6 +184,30 @@ func updateConfig(config *Config, input *OpenAIRequest) {
if input.MirostatTAU != 0 { if input.MirostatTAU != 0 {
config.MirostatTAU = input.MirostatTAU config.MirostatTAU = input.MirostatTAU
} }
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
if s, ok := pp.(string); ok {
config.InputStrings = append(config.InputStrings, s)
}
}
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
} }
func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
@ -268,19 +292,6 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Parameter Config: %+v", config) log.Debug().Msgf("Parameter Config: %+v", config)
predInput := []string{}
switch p := input.Prompt.(type) {
case string:
predInput = append(predInput, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
predInput = append(predInput, s)
}
}
}
templateFile := config.Model templateFile := config.Model
if config.TemplateConfig.Completion != "" { if config.TemplateConfig.Completion != "" {
@ -288,7 +299,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
} }
var result []Choice var result []Choice
for _, i := range predInput { for _, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct { templatedInput, err := loader.TemplatePrefix(templateFile, struct {
Input string Input string
@ -331,9 +342,12 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
} }
log.Debug().Msgf("Parameter Config: %+v", config) log.Debug().Msgf("Parameter Config: %+v", config)
items := []Item{}
for i, s := range config.InputStrings {
// get the model function to call for the result // get the model function to call for the result
embedFn, err := ModelEmbedding(input.Input, loader, *config) embedFn, err := ModelEmbedding(s, loader, *config)
if err != nil { if err != nil {
return err return err
} }
@ -342,9 +356,12 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
if err != nil { if err != nil {
return err return err
} }
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
}
resp := &OpenAIResponse{ resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Data: []Item{{Embedding: embeddings, Index: 0, Object: "embedding"}}, Data: items,
Object: "list", Object: "list",
} }
@ -480,30 +497,34 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
log.Debug().Msgf("Parameter Config: %+v", config) log.Debug().Msgf("Parameter Config: %+v", config)
predInput := input.Input
templateFile := config.Model templateFile := config.Model
if config.TemplateConfig.Edit != "" { if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit templateFile = config.TemplateConfig.Edit
} }
var result []Choice
for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct { templatedInput, err := loader.TemplatePrefix(templateFile, struct {
Input string Input string
Instruction string Instruction string
}{Input: predInput, Instruction: input.Instruction}) }{Input: i})
if err == nil { if err == nil {
predInput = templatedInput i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput) log.Debug().Msgf("Template found, input modified to: %s", i)
} }
result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) { r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s}) *c = append(*c, Choice{Text: s})
}, nil) }, nil)
if err != nil { if err != nil {
return err return err
} }
result = append(result, r...)
}
resp := &OpenAIResponse{ resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result, Choices: result,

View File

@ -28,6 +28,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
if c.Embeddings { if c.Embeddings {
llamaOpts = append(llamaOpts, llama.EnableEmbeddings) llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
} }
return llamaOpts return llamaOpts
} }