Make it compatible with the OpenAI API, support multiple models

Signed-off-by: mudler <mudler@c3os.io>
mudler 2023-04-07 11:30:59 +02:00
parent b33d015b8c
commit 12eee097b7
3 changed files with 176 additions and 12 deletions
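For orientation, a request against the new OpenAI-compatible endpoint could look like the minimal Go client below. This is only a sketch: it assumes the server listens on :8080, and the model file name is a placeholder for a file under --models-path (omitting "model" falls back to --default-model).

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The handler reads "model" and "messages" from the JSON body; the tuning
	// parameters (temperature, topP, topK, tokens) come from the query string.
	body := []byte(`{"model": "ggml-model.bin", "messages": [{"role": "user", "content": "What is an alpaca?"}]}`)
	resp, err := http.Post("http://localhost:8080/v1/chat/completions?temperature=0.5&tokens=128",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}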

api.go (112 changed lines)

@@ -2,24 +2,128 @@ package main
import (
"embed"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
)
type OpenAIResponse struct {
Created int `json:"created"`
Object string `json:"object"`
ID string `json:"id"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
}
type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Message Message `json:"message"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
//go:embed index.html
var indexHTML embed.FS
-func api(l *llama.LLama, listenAddr string, threads int) error {
+func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, threads int) error {
app := fiber.New()
// Default middleware config
app.Use(recover.New())
app.Use(cors.New())
app.Use("/", filesystem.New(filesystem.Config{
Root: http.FS(indexHTML),
NotFoundFile: "index.html",
}))
var mutex = &sync.Mutex{}
// openAI compatible API endpoint
app.Post("/v1/chat/completions", func(c *fiber.Ctx) error {
var err error
var model *llama.LLama
// Get input data from the request body
input := new(struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
})
if err := c.BodyParser(input); err != nil {
return err
}
if input.Model == "" {
if defaultModel == nil {
return fmt.Errorf("no default model loaded, and no model specified")
}
model = defaultModel
} else {
model, err = loader.LoadModel(input.Model)
if err != nil {
return err
}
}
// Set the parameters for the language model prediction
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
if err != nil {
return err
}
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
if err != nil {
return err
}
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
if err != nil {
return err
}
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
if err != nil {
return err
}
mess := []string{}
for _, i := range input.Messages {
mess = append(mess, i.Content)
}
fmt.Println("Received", input, input.Model)
// Generate the prediction using the language model
prediction, err := model.Predict(
strings.Join(mess, "\n"),
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
)
if err != nil {
return err
}
// Return the prediction in the response body
return c.JSON(OpenAIResponse{
Model: input.Model,
Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
})
})
/*
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
"text": "What is an alpaca?",
@@ -29,8 +133,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
"tokens": 100
}'
*/
-var mutex = &sync.Mutex{}
// Endpoint to generate the prediction
app.Post("/predict", func(c *fiber.Ctx) error {
mutex.Lock()
@@ -65,7 +167,7 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
}
// Generate the prediction using the language model
-prediction, err := l.Predict(
+prediction, err := defaultModel.Predict(
input.Text,
llama.SetTemperature(temperature),
llama.SetTopP(topP),
@@ -86,6 +188,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error {
})
// Start the server
-app.Listen(":8080")
+app.Listen(listenAddr)
return nil
}
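On the client side the reply can be decoded with the same struct shapes. A sketch, assuming the OpenAIResponse, Choice and Message types above are shared with (or copied into) the client package, with "encoding/json", "fmt" and "net/http" imported, and resp being the *http.Response from a request like the one sketched earlier:

// Sketch only: decode a /v1/chat/completions reply into the structs from api.go.
func printChoices(resp *http.Response) error {
	defer resp.Body.Close()
	var reply OpenAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&reply); err != nil {
		return err
	}
	for _, c := range reply.Choices {
		fmt.Printf("%s: %s\n", c.Message.Role, c.Message.Content)
	}
	return nil
}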

main.go (24 changed lines)

@@ -146,8 +146,12 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
Value: runtime.NumCPU(),
},
&cli.StringFlag{
Name: "model",
EnvVars: []string{"MODEL_PATH"},
Name: "models-path",
EnvVars: []string{"MODELS_PATH"},
},
&cli.StringFlag{
Name: "default-model",
EnvVars: []string{"default-model"},
},
&cli.StringFlag{
Name: "address",
@@ -161,13 +165,19 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
},
},
Action: func(ctx *cli.Context) error {
-l, err := llamaFromOptions(ctx)
-if err != nil {
-fmt.Println("Loading the model failed:", err.Error())
-os.Exit(1)
+var defaultModel *llama.LLama
+defModel := ctx.String("default-model")
+if defModel != "" {
+opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
+var err error
+defaultModel, err = llama.New(ctx.String("default-model"), opts...)
+if err != nil {
+return err
+}
}
-return api(l, ctx.String("address"), ctx.Int("threads"))
+return api(defaultModel, NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
},
},
},

model_loader.go (new file, 52 added lines)

@@ -0,0 +1,52 @@
package main
import (
"fmt"
"os"
"path/filepath"
"sync"
llama "github.com/go-skynet/go-llama.cpp"
)
type ModelLoader struct {
modelPath string
mu sync.Mutex
models map[string]*llama.LLama
}
func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama)}
}
func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LLama, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
modelFile := filepath.Join(ml.modelPath, s)
if m, ok := ml.models[modelFile]; ok {
return m, nil
}
// Check if the model path exists
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
// if not, try the same name with a ".bin" suffix
modelBin := fmt.Sprintf("%s.bin", modelFile)
if _, err := os.Stat(modelBin); os.IsNotExist(err) {
return nil, err
} else {
modelFile = modelBin
}
}
// Load the model and keep it in memory for later use
model, err := llama.New(modelFile, opts...)
if err != nil {
return nil, err
}
ml.models[modelFile] = model
return model, nil
}
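To illustrate how the loader is meant to be used, here is a sketch that assumes it lives in the same package as model_loader.go (so the llama and fmt imports above are available); the models directory, file name and context size are placeholders:

func exampleLoad() error {
	loader := NewModelLoader("/path/to/models")

	// First call loads /path/to/models/ggml-model.bin and caches the handle;
	// if the exact path did not exist, the loader would also try "<name>.bin".
	m1, err := loader.LoadModel("ggml-model.bin", llama.SetContext(512))
	if err != nil {
		return err
	}

	// A second call with the same name returns the cached instance.
	m2, err := loader.LoadModel("ggml-model.bin")
	if err != nil {
		return err
	}
	fmt.Println(m1 == m2) // true
	return nil
}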