// LocalAI/api/openai/inference.go (51 lines, 1.3 KiB, Go)

package openai
import (
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ComputeChoices(
req *schema.OpenAIRequest,
predInput string,
config *config.Config,
o *options.Option,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
2023-08-18 19:23:14 +00:00
n := req.N // number of completions to return
result := []schema.Choice{}
if n == 0 {
n = 1
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
if err != nil {
2023-08-18 19:23:14 +00:00
return result, backend.TokenUsage{}, err
}
2023-08-18 19:23:14 +00:00
tokenUsage := backend.TokenUsage{}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
2023-08-18 19:23:14 +00:00
return result, backend.TokenUsage{}, err
}
2023-08-18 19:23:14 +00:00
tokenUsage.Prompt += prediction.Usage.Prompt
tokenUsage.Completion += prediction.Usage.Completion
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
//result = append(result, Choice{Text: prediction})
}
2023-08-18 19:23:14 +00:00
return result, tokenUsage, err
}