2023-04-27 04:18:18 +00:00
|
|
|
package api
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
2023-05-27 12:29:11 +00:00
|
|
|
"os"
|
|
|
|
"path/filepath"
|
2023-04-29 07:22:09 +00:00
|
|
|
"regexp"
|
|
|
|
"strings"
|
2023-04-27 04:18:18 +00:00
|
|
|
"sync"
|
|
|
|
|
2023-05-03 09:45:22 +00:00
|
|
|
"github.com/donomii/go-rwkv.cpp"
|
2023-04-27 04:18:18 +00:00
|
|
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
2023-05-16 17:32:53 +00:00
|
|
|
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
|
2023-05-10 23:12:58 +00:00
|
|
|
"github.com/go-skynet/bloomz.cpp"
|
2023-05-10 13:20:21 +00:00
|
|
|
bert "github.com/go-skynet/go-bert.cpp"
|
2023-05-23 19:47:47 +00:00
|
|
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
2023-04-27 04:18:18 +00:00
|
|
|
llama "github.com/go-skynet/go-llama.cpp"
|
2023-05-16 17:32:53 +00:00
|
|
|
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
2023-04-27 04:18:18 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
// Per-model locks are still needed, see:
// https://github.com/ggerganov/llama.cpp/discussions/784
var (
	// mutexMap guards mutexes itself (lookups and insertions).
	mutexMap sync.Mutex
	// mutexes maps a lock key to the mutex serializing calls into that model.
	// In this file the key is the model file (ModelInference, ModelEmbedding)
	// or the backend name (ImageGeneration).
	mutexes = make(map[string]*sync.Mutex)
)
|
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
func defaultLLamaOpts(c Config) []llama.ModelOption {
|
|
|
|
llamaOpts := []llama.ModelOption{}
|
|
|
|
if c.ContextSize != 0 {
|
|
|
|
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
2023-05-05 09:20:06 +00:00
|
|
|
if c.F16 {
|
|
|
|
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
2023-05-05 09:20:06 +00:00
|
|
|
if c.Embeddings {
|
|
|
|
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
2023-05-05 13:54:59 +00:00
|
|
|
|
2023-05-16 14:26:25 +00:00
|
|
|
if c.NGPULayers != 0 {
|
|
|
|
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers))
|
|
|
|
}
|
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
return llamaOpts
|
|
|
|
}
|
2023-05-02 22:31:28 +00:00
|
|
|
|
2023-05-16 17:32:53 +00:00
|
|
|
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config) (func() error, error) {
|
|
|
|
if c.Backend != model.StableDiffusionBackend {
|
|
|
|
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
|
|
|
}
|
|
|
|
inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
var fn func() error
|
|
|
|
switch model := inferenceModel.(type) {
|
|
|
|
case *stablediffusion.StableDiffusion:
|
|
|
|
fn = func() error {
|
|
|
|
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst)
|
|
|
|
}
|
|
|
|
|
|
|
|
default:
|
|
|
|
fn = func() error {
|
|
|
|
return fmt.Errorf("creation of images not supported by the backend")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return func() error {
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
mutexMap.Lock()
|
|
|
|
l, ok := mutexes[c.Backend]
|
|
|
|
if !ok {
|
|
|
|
m := &sync.Mutex{}
|
|
|
|
mutexes[c.Backend] = m
|
|
|
|
l = m
|
|
|
|
}
|
|
|
|
mutexMap.Unlock()
|
|
|
|
l.Lock()
|
|
|
|
defer l.Unlock()
|
|
|
|
|
|
|
|
return fn()
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
2023-05-08 17:31:18 +00:00
|
|
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
|
2023-05-05 09:20:06 +00:00
|
|
|
if !c.Embeddings {
|
|
|
|
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
modelFile := c.Model
|
2023-05-02 22:31:28 +00:00
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
llamaOpts := defaultLLamaOpts(c)
|
2023-05-02 22:31:28 +00:00
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
var inferenceModel interface{}
|
|
|
|
var err error
|
|
|
|
if c.Backend == "" {
|
|
|
|
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads))
|
2023-05-02 22:31:28 +00:00
|
|
|
} else {
|
2023-05-05 09:20:06 +00:00
|
|
|
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads))
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
var fn func() ([]float32, error)
|
|
|
|
switch model := inferenceModel.(type) {
|
|
|
|
case *llama.LLama:
|
|
|
|
fn = func() ([]float32, error) {
|
2023-05-27 12:29:11 +00:00
|
|
|
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
|
2023-05-08 17:31:18 +00:00
|
|
|
if len(tokens) > 0 {
|
|
|
|
return model.TokenEmbeddings(tokens, predictOptions...)
|
|
|
|
}
|
2023-05-05 13:56:02 +00:00
|
|
|
return model.Embeddings(s, predictOptions...)
|
2023-05-05 09:20:06 +00:00
|
|
|
}
|
2023-05-10 13:20:21 +00:00
|
|
|
// bert embeddings
|
|
|
|
case *bert.Bert:
|
|
|
|
fn = func() ([]float32, error) {
|
|
|
|
if len(tokens) > 0 {
|
2023-05-12 15:16:49 +00:00
|
|
|
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
|
2023-05-10 13:20:21 +00:00
|
|
|
}
|
|
|
|
return model.Embeddings(s, bert.SetThreads(c.Threads))
|
|
|
|
}
|
2023-05-05 09:20:06 +00:00
|
|
|
default:
|
|
|
|
fn = func() ([]float32, error) {
|
|
|
|
return nil, fmt.Errorf("embeddings not supported by the backend")
|
|
|
|
}
|
2023-05-03 09:45:22 +00:00
|
|
|
}
|
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
return func() ([]float32, error) {
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
mutexMap.Lock()
|
|
|
|
l, ok := mutexes[modelFile]
|
|
|
|
if !ok {
|
|
|
|
m := &sync.Mutex{}
|
|
|
|
mutexes[modelFile] = m
|
|
|
|
l = m
|
|
|
|
}
|
|
|
|
mutexMap.Unlock()
|
|
|
|
l.Lock()
|
|
|
|
defer l.Unlock()
|
|
|
|
|
2023-05-05 16:05:29 +00:00
|
|
|
embeds, err := fn()
|
|
|
|
if err != nil {
|
|
|
|
return embeds, err
|
|
|
|
}
|
|
|
|
// Remove trailing 0s
|
|
|
|
for i := len(embeds) - 1; i >= 0; i-- {
|
|
|
|
if embeds[i] == 0.0 {
|
|
|
|
embeds = embeds[:i]
|
|
|
|
} else {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return embeds, nil
|
2023-05-05 09:20:06 +00:00
|
|
|
}, nil
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
|
|
|
|
2023-05-27 12:29:11 +00:00
|
|
|
func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption {
|
2023-05-05 13:56:02 +00:00
|
|
|
// Generate the prediction using the language model
|
|
|
|
predictOptions := []llama.PredictOption{
|
|
|
|
llama.SetTemperature(c.Temperature),
|
|
|
|
llama.SetTopP(c.TopP),
|
|
|
|
llama.SetTopK(c.TopK),
|
|
|
|
llama.SetTokens(c.Maxtokens),
|
|
|
|
llama.SetThreads(c.Threads),
|
|
|
|
}
|
|
|
|
|
2023-05-27 12:29:11 +00:00
|
|
|
if c.PromptCacheAll {
|
|
|
|
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.PromptCachePath != "" {
|
|
|
|
// Create parent directory
|
|
|
|
p := filepath.Join(modelPath, c.PromptCachePath)
|
|
|
|
os.MkdirAll(filepath.Dir(p), 0755)
|
|
|
|
predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
|
|
|
|
}
|
|
|
|
|
2023-05-05 13:56:02 +00:00
|
|
|
if c.Mirostat != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat))
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.MirostatETA != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA))
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.MirostatTAU != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU))
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Debug {
|
|
|
|
predictOptions = append(predictOptions, llama.Debug)
|
|
|
|
}
|
|
|
|
|
|
|
|
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...))
|
|
|
|
|
|
|
|
if c.RepeatPenalty != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty))
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Keep != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep))
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetBatch(c.Batch))
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.F16 {
|
|
|
|
predictOptions = append(predictOptions, llama.EnableF16KV)
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.IgnoreEOS {
|
|
|
|
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
|
|
|
predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
|
|
|
|
}
|
|
|
|
|
|
|
|
return predictOptions
|
|
|
|
}
|
|
|
|
|
2023-05-02 18:03:35 +00:00
|
|
|
func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) {
|
|
|
|
supportStreams := false
|
2023-04-27 04:18:18 +00:00
|
|
|
modelFile := c.Model
|
|
|
|
|
2023-05-05 09:20:06 +00:00
|
|
|
llamaOpts := defaultLLamaOpts(c)
|
2023-04-27 04:18:18 +00:00
|
|
|
|
2023-05-02 22:31:28 +00:00
|
|
|
var inferenceModel interface{}
|
|
|
|
var err error
|
|
|
|
if c.Backend == "" {
|
2023-05-05 09:20:06 +00:00
|
|
|
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads))
|
2023-05-02 22:31:28 +00:00
|
|
|
} else {
|
2023-05-05 09:20:06 +00:00
|
|
|
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads))
|
2023-05-02 22:31:28 +00:00
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
var fn func() (string, error)
|
|
|
|
|
2023-05-02 22:31:28 +00:00
|
|
|
switch model := inferenceModel.(type) {
|
2023-05-03 09:45:22 +00:00
|
|
|
case *rwkv.RwkvState:
|
|
|
|
supportStreams = true
|
|
|
|
|
|
|
|
fn = func() (string, error) {
|
|
|
|
stopWord := "\n"
|
|
|
|
if len(c.StopWords) > 0 {
|
|
|
|
stopWord = c.StopWords[0]
|
|
|
|
}
|
|
|
|
|
2023-05-04 15:32:23 +00:00
|
|
|
if err := model.ProcessInput(s); err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
2023-05-03 09:45:22 +00:00
|
|
|
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
|
|
|
|
|
|
|
|
return response, nil
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.GPTNeoX:
|
2023-05-12 09:36:35 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-05-12 09:36:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-05-12 09:36:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-05-12 09:36:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return model.Predict(
|
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.Replit:
|
2023-05-12 09:36:35 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-05-12 09:36:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-05-12 09:36:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-05-12 09:36:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return model.Predict(
|
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.Starcoder:
|
2023-05-11 18:20:07 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-05-11 18:20:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-05-11 18:20:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-05-11 18:20:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return model.Predict(
|
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.MPT:
|
2023-05-10 23:12:58 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-05-10 23:12:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-05-10 23:12:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-05-10 23:12:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return model.Predict(
|
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
|
|
|
case *bloomz.Bloomz:
|
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
|
|
|
predictOptions := []bloomz.PredictOption{
|
|
|
|
bloomz.SetTemperature(c.Temperature),
|
|
|
|
bloomz.SetTopP(c.TopP),
|
|
|
|
bloomz.SetTopK(c.TopK),
|
|
|
|
bloomz.SetTokens(c.Maxtokens),
|
|
|
|
bloomz.SetThreads(c.Threads),
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
|
|
|
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
|
|
|
}
|
|
|
|
|
|
|
|
return model.Predict(
|
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.GPTJ:
|
2023-04-27 04:18:18 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
2023-05-02 22:31:28 +00:00
|
|
|
return model.Predict(
|
2023-04-27 04:18:18 +00:00
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.Dolly:
|
2023-05-10 23:12:58 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-05-10 23:12:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-05-10 23:12:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-05-10 23:12:58 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return model.Predict(
|
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-23 19:47:47 +00:00
|
|
|
case *transformers.GPT2:
|
2023-04-27 04:18:18 +00:00
|
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions := []transformers.PredictOption{
|
|
|
|
transformers.SetTemperature(c.Temperature),
|
|
|
|
transformers.SetTopP(c.TopP),
|
|
|
|
transformers.SetTopK(c.TopK),
|
|
|
|
transformers.SetTokens(c.Maxtokens),
|
|
|
|
transformers.SetThreads(c.Threads),
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Batch != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if c.Seed != 0 {
|
2023-05-23 19:47:47 +00:00
|
|
|
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
2023-05-02 22:31:28 +00:00
|
|
|
return model.Predict(
|
2023-04-27 04:18:18 +00:00
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
|
|
|
}
|
2023-05-11 12:31:19 +00:00
|
|
|
case *gpt4all.Model:
|
|
|
|
supportStreams = true
|
|
|
|
|
2023-04-27 04:18:18 +00:00
|
|
|
fn = func() (string, error) {
|
2023-05-11 12:31:19 +00:00
|
|
|
if tokenCallback != nil {
|
|
|
|
model.SetTokenCallback(tokenCallback)
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
2023-05-11 12:31:19 +00:00
|
|
|
// Generate the prediction using the language model
|
|
|
|
predictOptions := []gpt4all.PredictOption{
|
|
|
|
gpt4all.SetTemperature(c.Temperature),
|
|
|
|
gpt4all.SetTopP(c.TopP),
|
|
|
|
gpt4all.SetTopK(c.TopK),
|
|
|
|
gpt4all.SetTokens(c.Maxtokens),
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
|
2023-05-11 12:31:19 +00:00
|
|
|
if c.Batch != 0 {
|
|
|
|
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch))
|
2023-05-02 14:07:18 +00:00
|
|
|
}
|
|
|
|
|
2023-05-11 12:31:19 +00:00
|
|
|
str, er := model.Predict(
|
2023-04-27 04:18:18 +00:00
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
2023-05-11 12:31:19 +00:00
|
|
|
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
|
|
|
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
|
|
|
// after a stream event has occurred
|
|
|
|
model.SetTokenCallback(nil)
|
|
|
|
return str, er
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
2023-05-02 22:31:28 +00:00
|
|
|
case *llama.LLama:
|
2023-05-02 18:03:35 +00:00
|
|
|
supportStreams = true
|
2023-04-27 04:18:18 +00:00
|
|
|
fn = func() (string, error) {
|
2023-05-02 18:03:35 +00:00
|
|
|
|
|
|
|
if tokenCallback != nil {
|
|
|
|
model.SetTokenCallback(tokenCallback)
|
|
|
|
}
|
|
|
|
|
2023-05-27 12:29:11 +00:00
|
|
|
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
|
2023-04-27 04:18:18 +00:00
|
|
|
|
2023-05-04 17:49:43 +00:00
|
|
|
str, er := model.Predict(
|
2023-04-27 04:18:18 +00:00
|
|
|
s,
|
|
|
|
predictOptions...,
|
|
|
|
)
|
2023-05-04 17:49:43 +00:00
|
|
|
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
|
|
|
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
|
|
|
// after a stream event has occurred
|
|
|
|
model.SetTokenCallback(nil)
|
|
|
|
return str, er
|
2023-04-27 04:18:18 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return func() (string, error) {
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
mutexMap.Lock()
|
|
|
|
l, ok := mutexes[modelFile]
|
|
|
|
if !ok {
|
|
|
|
m := &sync.Mutex{}
|
|
|
|
mutexes[modelFile] = m
|
|
|
|
l = m
|
|
|
|
}
|
|
|
|
mutexMap.Unlock()
|
|
|
|
l.Lock()
|
|
|
|
defer l.Unlock()
|
|
|
|
|
2023-05-02 18:03:35 +00:00
|
|
|
res, err := fn()
|
|
|
|
if tokenCallback != nil && !supportStreams {
|
|
|
|
tokenCallback(res)
|
|
|
|
}
|
|
|
|
return res, err
|
2023-04-27 04:18:18 +00:00
|
|
|
}, nil
|
|
|
|
}
|
2023-04-29 07:22:09 +00:00
|
|
|
|
2023-05-02 18:03:35 +00:00
|
|
|
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
|
2023-04-29 07:22:09 +00:00
|
|
|
result := []Choice{}
|
|
|
|
|
|
|
|
n := input.N
|
|
|
|
|
|
|
|
if input.N == 0 {
|
|
|
|
n = 1
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the model function to call for the result
|
2023-05-02 18:03:35 +00:00
|
|
|
predFunc, err := ModelInference(predInput, loader, *config, tokenCallback)
|
2023-04-29 07:22:09 +00:00
|
|
|
if err != nil {
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := 0; i < n; i++ {
|
|
|
|
prediction, err := predFunc()
|
|
|
|
if err != nil {
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
|
|
|
prediction = Finetune(*config, predInput, prediction)
|
|
|
|
cb(prediction, &result)
|
|
|
|
|
|
|
|
//result = append(result, Choice{Text: prediction})
|
|
|
|
|
|
|
|
}
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
|
|
|
var (
	// cutstrings caches compiled Config.Cutstrings patterns keyed by their
	// source text, so patterns are not recompiled on every call.
	cutstrings = make(map[string]*regexp.Regexp)
	// mu guards cutstrings.
	mu sync.Mutex
)
|
|
|
|
|
|
|
|
func Finetune(config Config, input, prediction string) string {
|
|
|
|
if config.Echo {
|
|
|
|
prediction = input + prediction
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, c := range config.Cutstrings {
|
|
|
|
mu.Lock()
|
|
|
|
reg, ok := cutstrings[c]
|
|
|
|
if !ok {
|
|
|
|
cutstrings[c] = regexp.MustCompile(c)
|
|
|
|
reg = cutstrings[c]
|
|
|
|
}
|
|
|
|
mu.Unlock()
|
|
|
|
prediction = reg.ReplaceAllString(prediction, "")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, c := range config.TrimSpace {
|
|
|
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
|
|
|
}
|
|
|
|
return prediction
|
|
|
|
|
|
|
|
}
|