feat: support slices or strings in the prompt completion endpoint (#162)

Signed-off-by: mudler <mudler@mocaccino.org>
This commit is contained in:
Ettore Di Giacinto 2023-05-03 13:13:31 +02:00 committed by GitHub
parent 0a4899f366
commit 67992a7d99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -57,7 +57,7 @@ type OpenAIRequest struct {
Model string `json:"model" yaml:"model"`
// Prompt is read only by completion API calls
Prompt string `json:"prompt" yaml:"prompt"`
Prompt interface{} `json:"prompt" yaml:"prompt"`
// Edit endpoint
Instruction string `json:"instruction" yaml:"instruction"`
@ -122,9 +122,12 @@ func updateConfig(config *Config, input *OpenAIRequest) {
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []string:
config.StopWords = append(config.StopWords, stop...)
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if input.RepeatPenalty != 0 {
@ -234,27 +237,44 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Parameter Config: %+v", config)
predInput := input.Prompt
predInput := []string{}
switch p := input.Prompt.(type) {
case string:
predInput = append(predInput, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
predInput = append(predInput, s)
}
}
}
templateFile := config.Model
if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
}
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
var result []Choice
for _, i := range predInput {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: i})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {
return err
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {
return err
}
result = append(result, r...)
}
resp := &OpenAIResponse{