mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
255748bcba
This PR specifically introduces a `core` folder and moves the following packages into it, without any other changes: `api/backend`, `api/config`, `api/options`, and `api/schema`. Once this is merged and we confirm there are no regressions, I can migrate the remaining changes over piece by piece to split up application startup, backend services, HTTP, and MQTT, as was the goal of the earlier PRs!
200 lines
5.6 KiB
Go
200 lines
5.6 KiB
Go
package openai
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/go-skynet/LocalAI/core/backend"
|
|
config "github.com/go-skynet/LocalAI/core/config"
|
|
"github.com/go-skynet/LocalAI/core/options"
|
|
"github.com/go-skynet/LocalAI/core/schema"
|
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// https://platform.openai.com/docs/api-reference/completions
|
|
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
|
id := uuid.New().String()
|
|
created := int(time.Now().Unix())
|
|
|
|
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
|
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
|
resp := schema.OpenAIResponse{
|
|
ID: id,
|
|
Created: created,
|
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
Choices: []schema.Choice{
|
|
{
|
|
Index: 0,
|
|
Text: s,
|
|
},
|
|
},
|
|
Object: "text_completion",
|
|
Usage: schema.OpenAIUsage{
|
|
PromptTokens: usage.Prompt,
|
|
CompletionTokens: usage.Completion,
|
|
TotalTokens: usage.Prompt + usage.Completion,
|
|
},
|
|
}
|
|
log.Debug().Msgf("Sending goroutine: %s", s)
|
|
|
|
responses <- resp
|
|
return true
|
|
})
|
|
close(responses)
|
|
}
|
|
|
|
return func(c *fiber.Ctx) error {
|
|
modelFile, input, err := readRequest(c, o, true)
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
}
|
|
|
|
log.Debug().Msgf("`input`: %+v", input)
|
|
|
|
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
}
|
|
|
|
if input.ResponseFormat.Type == "json_object" {
|
|
input.Grammar = grammar.JSONBNF
|
|
}
|
|
|
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
|
|
if input.Stream {
|
|
log.Debug().Msgf("Stream request received")
|
|
c.Context().SetContentType("text/event-stream")
|
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
//c.Set("Content-Type", "text/event-stream")
|
|
c.Set("Cache-Control", "no-cache")
|
|
c.Set("Connection", "keep-alive")
|
|
c.Set("Transfer-Encoding", "chunked")
|
|
}
|
|
|
|
templateFile := ""
|
|
|
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
|
templateFile = config.Model
|
|
}
|
|
|
|
if config.TemplateConfig.Completion != "" {
|
|
templateFile = config.TemplateConfig.Completion
|
|
}
|
|
|
|
if input.Stream {
|
|
if len(config.PromptStrings) > 1 {
|
|
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
|
}
|
|
|
|
predInput := config.PromptStrings[0]
|
|
|
|
if templateFile != "" {
|
|
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
|
Input: predInput,
|
|
})
|
|
if err == nil {
|
|
predInput = templatedInput
|
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
|
}
|
|
}
|
|
|
|
responses := make(chan schema.OpenAIResponse)
|
|
|
|
go process(predInput, input, config, o.Loader, responses)
|
|
|
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
|
|
|
for ev := range responses {
|
|
var buf bytes.Buffer
|
|
enc := json.NewEncoder(&buf)
|
|
enc.Encode(ev)
|
|
|
|
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
|
fmt.Fprintf(w, "data: %v\n", buf.String())
|
|
w.Flush()
|
|
}
|
|
|
|
resp := &schema.OpenAIResponse{
|
|
ID: id,
|
|
Created: created,
|
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
Choices: []schema.Choice{
|
|
{
|
|
Index: 0,
|
|
FinishReason: "stop",
|
|
},
|
|
},
|
|
Object: "text_completion",
|
|
}
|
|
respData, _ := json.Marshal(resp)
|
|
|
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
|
w.WriteString("data: [DONE]\n\n")
|
|
w.Flush()
|
|
}))
|
|
return nil
|
|
}
|
|
|
|
var result []schema.Choice
|
|
|
|
totalTokenUsage := backend.TokenUsage{}
|
|
|
|
for k, i := range config.PromptStrings {
|
|
if templateFile != "" {
|
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
|
SystemPrompt: config.SystemPrompt,
|
|
Input: i,
|
|
})
|
|
if err == nil {
|
|
i = templatedInput
|
|
log.Debug().Msgf("Template found, input modified to: %s", i)
|
|
}
|
|
}
|
|
|
|
r, tokenUsage, err := ComputeChoices(
|
|
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
|
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
|
}, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
totalTokenUsage.Prompt += tokenUsage.Prompt
|
|
totalTokenUsage.Completion += tokenUsage.Completion
|
|
|
|
result = append(result, r...)
|
|
}
|
|
|
|
resp := &schema.OpenAIResponse{
|
|
ID: id,
|
|
Created: created,
|
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
Choices: result,
|
|
Object: "text_completion",
|
|
Usage: schema.OpenAIUsage{
|
|
PromptTokens: totalTokenUsage.Prompt,
|
|
CompletionTokens: totalTokenUsage.Completion,
|
|
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
|
},
|
|
}
|
|
|
|
jsonResult, _ := json.Marshal(resp)
|
|
log.Debug().Msgf("Response: %s", jsonResult)
|
|
|
|
// Return the prediction in the response body
|
|
return c.JSON(resp)
|
|
}
|
|
}
|