LocalAI/api/openai/inference.go

43 lines
1.2 KiB
Go
Raw Normal View History

package openai
import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
model "github.com/go-skynet/LocalAI/pkg/model"
)
2023-08-18 19:23:14 +00:00
func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string, backend.TokenUsage) bool) ([]Choice, backend.TokenUsage, error) {
n := req.N // number of completions to return
result := []Choice{}
if n == 0 {
n = 1
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
if err != nil {
2023-08-18 19:23:14 +00:00
return result, backend.TokenUsage{}, err
}
2023-08-18 19:23:14 +00:00
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
2023-08-18 19:23:14 +00:00
return result, backend.TokenUsage{}, err
}
2023-08-18 19:23:14 +00:00
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
2023-08-18 19:23:14 +00:00
return result, tokenUsage, err
}