package api

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"

	model "github.com/go-skynet/LocalAI/pkg/model"

	gpt2 "github.com/go-skynet/go-gpt2.cpp"
	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
	llama "github.com/go-skynet/go-llama.cpp"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/recover"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

type OpenAIResponse struct {
	Created int      `json:"created,omitempty"`
	Object  string   `json:"object,omitempty"`
	ID      string   `json:"id,omitempty"`
	Model   string   `json:"model,omitempty"`
	Choices []Choice `json:"choices,omitempty"`
}

type Choice struct {
	Index        int      `json:"index,omitempty"`
	FinishReason string   `json:"finish_reason,omitempty"`
	Message      *Message `json:"message,omitempty"`
	Text         string   `json:"text,omitempty"`
}

type Message struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

type OpenAIModel struct {
	ID     string `json:"id"`
	Object string `json:"object"`
}
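
// OpenAIRequest is the (partial) OpenAI-compatible request body accepted by the
// completions and chat endpoints. An illustrative payload (values are examples,
// not defaults enforced by this package):
//
//	{
//	  "model": "ggml-model.bin",
//	  "prompt": "What is an alpaca?",
//	  "temperature": 0.9,
//	  "top_p": 0.7,
//	  "top_k": 80,
//	  "max_tokens": 512
//	}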
type OpenAIRequest struct {
	Model string `json:"model"`

	// Prompt is only read by the completions API calls
	Prompt string `json:"prompt"`

	Stop string `json:"stop"`

	// Messages is only read by the chat/completions API calls
	Messages []Message `json:"messages"`

	Echo bool `json:"echo"`

	// Common options between all the API calls
	TopP        float64 `json:"top_p"`
	TopK        int     `json:"top_k"`
	Temperature float64 `json:"temperature"`
	Maxtokens   int     `json:"max_tokens"`
	N           int     `json:"n"`

	// Custom parameters - not present in the OpenAI API
	Batch         int     `json:"batch"`
	F16           bool    `json:"f16kv"`
	IgnoreEOS     bool    `json:"ignore_eos"`
	RepeatPenalty float64 `json:"repeat_penalty"`
	Keep          int     `json:"n_keep"`
	Seed          int     `json:"seed"`
}
// https://platform.openai.com/docs/api-reference/completions
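//
// openAIEndpoint builds the Fiber handler shared by the chat and completions
// routes: chat selects whether the request's Messages or Prompt field drives the
// prediction, debug enables the llama backend's debug output, and the mutex map
// serializes predictions per model file.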
func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		var err error
		var model *llama.LLama
		var gptModel *gptj.GPTJ
		var gpt2Model *gpt2.GPT2
		var stableLMModel *gpt2.StableLM

		input := new(OpenAIRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		modelFile := input.Model
		received, _ := json.Marshal(input)

		log.Debug().Msgf("Request received: %s", string(received))

		// Set model from bearer token, if available
		bearer := strings.TrimPrefix(c.Get("authorization"), "Bearer ")
		bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)

		// If no model was specified, take the first available
		if modelFile == "" {
			models, _ := loader.ListModels()
			if len(models) > 0 {
				modelFile = models[0]
				log.Debug().Msgf("No model specified, using: %s", modelFile)
			}
		}

		// If no model is found or specified, we bail out
		if modelFile == "" && !bearerExists {
			return fmt.Errorf("no model specified")
		}

		// A model specified via the bearer token takes precedence
		if bearerExists {
			log.Debug().Msgf("Using model from bearer token: %s", bearer)
			modelFile = bearer
		}

		// Try to load the model
		var llamaerr, gpt2err, gptjerr, stableerr error
		llamaOpts := []llama.ModelOption{}
		if ctx != 0 {
			llamaOpts = append(llamaOpts, llama.SetContext(ctx))
		}
		if f16 {
			llamaOpts = append(llamaOpts, llama.EnableF16Memory)
		}

		// TODO: this is ugly; the model type should be identified properly. Still, it is a good stab for a first implementation.
		model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
		if llamaerr != nil {
			gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
			if gptjerr != nil {
				gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
				if gpt2err != nil {
					stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
					if stableerr != nil {
						// llama failed first, so report the errors from every backend we tried
						return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error())
					}
				}
			}
		}
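
		// Acquire (and lazily create) a per-model mutex so that predictions for the
		// same model file never run concurrently.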
		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
		mutexMap.Lock()
		l, ok := mutexes[modelFile]
		if !ok {
			m := &sync.Mutex{}
			mutexes[modelFile] = m
			l = m
		}
		mutexMap.Unlock()
		l.Lock()
		defer l.Unlock()

		// Set the parameters for the language model prediction
		topP := input.TopP
		if topP == 0 {
			topP = 0.7
		}

		topK := input.TopK
		if topK == 0 {
			topK = 80
		}

		temperature := input.Temperature
		if temperature == 0 {
			temperature = 0.9
		}

		tokens := input.Maxtokens
		if tokens == 0 {
			tokens = 512
		}

		predInput := input.Prompt

		if chat {
			mess := []string{}
			// TODO: encode roles
			for _, i := range input.Messages {
				mess = append(mess, i.Content)
			}

			predInput = strings.Join(mess, "\n")
		}

		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
		templatedInput, err := loader.TemplatePrefix(modelFile, struct {
			Input string
		}{Input: predInput})
		if err == nil {
			predInput = templatedInput
			log.Debug().Msgf("Template found, input modified to: %s", predInput)
		}
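
		// A minimal illustrative sketch of what such a template file could contain
		// (the file name and wording are made up here; the data passed in only
		// exposes an Input field, so {{.Input}} is the placeholder to use, assuming
		// the loader renders the file with Go's text/template):
		//
		//	file: ggml-model.bin.tmpl
		//
		//	Below is an instruction that describes a task. Write a response that appropriately completes the request.
		//	### Instruction:
		//	{{.Input}}
		//	### Response: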

		result := []Choice{}

		n := input.N
		if input.N == 0 {
			n = 1
		}
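
		// Build the prediction closure for whichever backend successfully loaded the model file.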
		var predFunc func() (string, error)
		switch {
		case stableLMModel != nil:
			predFunc = func() (string, error) {
				// Generate the prediction using the language model
				predictOptions := []gpt2.PredictOption{
					gpt2.SetTemperature(temperature),
					gpt2.SetTopP(topP),
					gpt2.SetTopK(topK),
					gpt2.SetTokens(tokens),
					gpt2.SetThreads(threads),
				}
				if input.Batch != 0 {
					predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
				}
				if input.Seed != 0 {
					predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
				}
				return stableLMModel.Predict(
					predInput,
					predictOptions...,
				)
			}
		case gpt2Model != nil:
			predFunc = func() (string, error) {
				// Generate the prediction using the language model
				predictOptions := []gpt2.PredictOption{
					gpt2.SetTemperature(temperature),
					gpt2.SetTopP(topP),
					gpt2.SetTopK(topK),
					gpt2.SetTokens(tokens),
					gpt2.SetThreads(threads),
				}
				if input.Batch != 0 {
					predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
				}
				if input.Seed != 0 {
					predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
				}
				return gpt2Model.Predict(
					predInput,
					predictOptions...,
				)
			}
		case gptModel != nil:
			predFunc = func() (string, error) {
				// Generate the prediction using the language model
				predictOptions := []gptj.PredictOption{
					gptj.SetTemperature(temperature),
					gptj.SetTopP(topP),
					gptj.SetTopK(topK),
					gptj.SetTokens(tokens),
					gptj.SetThreads(threads),
				}
				if input.Batch != 0 {
					predictOptions = append(predictOptions, gptj.SetBatch(input.Batch))
				}
				if input.Seed != 0 {
					predictOptions = append(predictOptions, gptj.SetSeed(input.Seed))
				}
				return gptModel.Predict(
					predInput,
					predictOptions...,
				)
			}
		case model != nil:
			predFunc = func() (string, error) {
				// Generate the prediction using the language model
				predictOptions := []llama.PredictOption{
					llama.SetTemperature(temperature),
					llama.SetTopP(topP),
					llama.SetTopK(topK),
					llama.SetTokens(tokens),
					llama.SetThreads(threads),
				}
				if debug {
					predictOptions = append(predictOptions, llama.Debug)
				}
				if input.Stop != "" {
					predictOptions = append(predictOptions, llama.SetStopWords(input.Stop))
				}
				if input.RepeatPenalty != 0 {
					predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty))
				}
				if input.Keep != 0 {
					predictOptions = append(predictOptions, llama.SetNKeep(input.Keep))
				}
				if input.Batch != 0 {
					predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
				}
				if input.F16 {
					predictOptions = append(predictOptions, llama.EnableF16KV)
				}
				if input.IgnoreEOS {
					predictOptions = append(predictOptions, llama.IgnoreEOS)
				}
				if input.Seed != 0 {
					predictOptions = append(predictOptions, llama.SetSeed(input.Seed))
				}
				return model.Predict(
					predInput,
					predictOptions...,
				)
			}
		}

		for i := 0; i < n; i++ {
			prediction, err := predFunc()
			if err != nil {
				return err
			}

			if input.Echo {
				prediction = predInput + prediction
			}

			if chat {
				result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
			} else {
				result = append(result, Choice{Text: prediction})
			}
		}

		jsonResult, _ := json.Marshal(result)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(OpenAIResponse{
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: result,
		})
	}
}

func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		models, err := loader.ListModels()
		if err != nil {
			return err
		}

		dataModels := []OpenAIModel{}
		for _, m := range models {
			dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
		}
		return c.JSON(struct {
			Object string        `json:"object"`
			Data   []OpenAIModel `json:"data"`
		}{
			Object: "list",
			Data:   dataModels,
		})
	}
}
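
// App wires up the Fiber application: log level, JSON error responses, CORS and
// panic-recovery middleware, and the OpenAI-compatible routes registered below.
//
// A minimal sketch of how a client could call the resulting server, assuming it
// is listening on localhost:8080 and a model file named "ggml-model.bin" exists
// in the model path (both values are illustrative, not defaults set here):
//
//	curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
//	  "model": "ggml-model.bin",
//	  "messages": [{"role": "user", "content": "How are you?"}],
//	  "temperature": 0.9
//	}'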
func App(loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	if debug {
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	}

	// Return errors as JSON responses
	app := fiber.New(fiber.Config{
		DisableStartupMessage: disableMessage,
		// Override default error handler
		ErrorHandler: func(ctx *fiber.Ctx, err error) error {
			// Status code defaults to 500
			code := fiber.StatusInternalServerError

			// Retrieve the custom status code if it's a *fiber.Error
			var e *fiber.Error
			if errors.As(err, &e) {
				code = e.Code
			}

			// Send custom error page
			return ctx.Status(code).JSON(struct {
				Error string `json:"error"`
			}{Error: err.Error()})
		},
	})

	// Default middleware config
	app.Use(recover.New())
	app.Use(cors.New())

	// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
	mu := map[string]*sync.Mutex{}
	var mumutex = &sync.Mutex{}

	// openAI compatible API endpoints
	app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
	app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))

	app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
	app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))

	app.Get("/v1/models", listModels(loader))
	app.Get("/models", listModels(loader))

	return app
}