Allow to template model prompts inputs

This commit is contained in:
mudler 2023-04-08 10:46:51 +02:00
parent 48aca246e3
commit 9fb581739b
2 changed files with 59 additions and 8 deletions

12
api.go
View File

@ -103,10 +103,18 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
mess = append(mess, i.Content)
}
fmt.Println("Received", input, input.Model)
predInput := strings.Join(mess, "\n")
templatedInput, err := loader.TemplatePrefix(input.Model, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
}
// Generate the prediction using the language model
prediction, err := model.Predict(
strings.Join(mess, "\n"),
templatedInput,
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),

View File

@ -1,30 +1,55 @@
package main
import (
"bytes"
"fmt"
"os"
"path/filepath"
"sync"
"text/template"
llama "github.com/go-skynet/go-llama.cpp"
)
type ModelLoader struct {
modelPath string
mu sync.Mutex
models map[string]*llama.LLama
modelPath string
mu sync.Mutex
models map[string]*llama.LLama
promptsTemplates map[string]*template.Template
}
func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama)}
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
}
func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LLama, error) {
func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
m, ok := ml.promptsTemplates[modelName]
if !ok {
// try to find a s.bin
modelBin := fmt.Sprintf("%s.bin", modelName)
m, ok = ml.promptsTemplates[modelBin]
if !ok {
return "", fmt.Errorf("no prompt template available")
}
}
var buf bytes.Buffer
if err := m.Execute(&buf, in); err != nil {
return "", err
}
return buf.String(), nil
}
func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
modelFile := filepath.Join(ml.modelPath, s)
modelFile := filepath.Join(ml.modelPath, modelName)
if m, ok := ml.models[modelFile]; ok {
return m, nil
@ -47,6 +72,24 @@ func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LL
return nil, err
}
// If there is a prompt template, load it
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
// Check if the model path exists
if _, err := os.Stat(modelTemplateFile); err == nil {
dat, err := os.ReadFile(modelTemplateFile)
if err != nil {
return nil, err
}
// Parse the template
tmpl, err := template.New("prompt").Parse(string(dat))
if err != nil {
return nil, err
}
ml.promptsTemplates[modelFile] = tmpl
}
ml.models[modelFile] = model
return model, err
}