2023-07-14 23:19:43 +00:00
|
|
|
package openai
|
|
|
|
|
|
|
|
import (
|
|
|
|
"github.com/go-skynet/LocalAI/api/backend"
|
|
|
|
config "github.com/go-skynet/LocalAI/api/config"
|
|
|
|
"github.com/go-skynet/LocalAI/api/options"
|
|
|
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
|
|
)
|
|
|
|
|
2023-07-24 21:10:54 +00:00
|
|
|
func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
|
|
|
|
n := req.N
|
2023-07-14 23:19:43 +00:00
|
|
|
result := []Choice{}
|
|
|
|
|
|
|
|
if n == 0 {
|
|
|
|
n = 1
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the model function to call for the result
|
2023-07-24 21:10:54 +00:00
|
|
|
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
|
2023-07-14 23:19:43 +00:00
|
|
|
if err != nil {
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := 0; i < n; i++ {
|
|
|
|
prediction, err := predFunc()
|
|
|
|
if err != nil {
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
|
|
|
prediction = backend.Finetune(*config, predInput, prediction)
|
|
|
|
cb(prediction, &result)
|
|
|
|
|
|
|
|
//result = append(result, Choice{Text: prediction})
|
|
|
|
|
|
|
|
}
|
|
|
|
return result, err
|
|
|
|
}
|