feat: cancel stream generation if client disappears (#792)

Aman Gupta Karmani 2023-07-24 14:10:54 -07:00 committed by GitHub
parent 72e3e236de
commit 12fe0932c4
12 changed files with 37 additions and 21 deletions

View File

@@ -1,6 +1,7 @@
package backend
import (
"context"
"os"
"regexp"
"strings"
@@ -14,7 +15,7 @@ import (
"github.com/go-skynet/LocalAI/pkg/utils"
)
-func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
+func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
modelFile := c.Model
grpcOpts := gRPCModelOpts(c)
@@ -66,13 +67,13 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
opts.Prompt = s
if tokenCallback != nil {
ss := ""
-err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
+err := inferenceModel.PredictStream(ctx, opts, func(s string) {
tokenCallback(s)
ss += s
})
return ss, err
} else {
-reply, err := inferenceModel.Predict(o.Context, opts)
+reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return "", err
}
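
Note: the substantive change in this file is threading a caller-supplied `context.Context` into `ModelInference` instead of always using the server-wide `o.Context`, so the gRPC `Predict`/`PredictStream` calls observe per-request cancellation. A minimal, self-contained sketch of the idea, with `predictStream` standing in for the backend call (illustrative, not actual LocalAI code):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// predictStream stands in for an inferenceModel.PredictStream-style call:
// it emits tokens until the context is cancelled.
func predictStream(ctx context.Context, tokenCallback func(string)) error {
	for i := 0; ; i++ {
		select {
		case <-ctx.Done():
			return ctx.Err() // client vanished (or server shut down): stop generating
		case <-time.After(50 * time.Millisecond):
			tokenCallback(fmt.Sprintf("token-%d ", i))
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(200 * time.Millisecond)
		cancel() // simulate the client disconnecting mid-stream
	}()
	err := predictStream(ctx, func(s string) { fmt.Print(s) })
	fmt.Println("\nstream ended:", err)
}
```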

View File

@@ -1,6 +1,7 @@
package openai
import (
"context"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/grammar"
@@ -70,6 +71,9 @@ type OpenAIModel struct {
type OpenAIRequest struct {
config.PredictionOptions
+Context context.Context
+Cancel  context.CancelFunc
// whisper
File string `json:"file" validate:"required"`
//whisper/image

View File

@@ -28,7 +28,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
responses <- initialMessage
-ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
@@ -43,7 +43,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return func(c *fiber.Ctx) error {
processFunctions := false
funcs := grammar.Functions{}
-modelFile, input, err := readInput(c, o.Loader, true)
+modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -235,7 +235,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
-fmt.Fprintf(w, "data: %v\n", buf.String())
+_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+if err != nil {
+    log.Debug().Msgf("Sending chunk failed: %v", err)
+    input.Cancel()
+    break
+}
w.Flush()
}
@@ -258,7 +263,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return nil
}
-result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
if processFunctions {
// As we have to change the result before processing, we can't stream the answer (yet?)
ss := map[string]interface{}{}
@@ -300,7 +305,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
// Note: This costs (in term of CPU) another computation
config.Grammar = ""
-predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil)
+predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
if err != nil {
log.Error().Msgf("inference error: %s", err.Error())
return
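
Note: the write-error branch added above is the heart of the feature. Fiber's `StreamWriter` runs on fasthttp, which does not expose a per-request context that fires on client disconnect the way `net/http`'s `Request.Context()` does, so a failed chunk write is the available disconnect signal; `input.Cancel()` then tears down the inference. A rough standard-library analogue (the `generate` token source, route, and port are illustrative assumptions, not LocalAI code):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"time"
)

// generate is a hypothetical token source that stops as soon as ctx ends.
func generate(ctx context.Context) <-chan string {
	out := make(chan string)
	go func() {
		defer close(out)
		for i := 0; ; i++ {
			select {
			case <-ctx.Done():
				return
			case out <- fmt.Sprintf("token-%d", i):
				time.Sleep(50 * time.Millisecond)
			}
		}
	}()
	return out
}

func streamHandler(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/event-stream")

	for chunk := range generate(ctx) {
		// A failed write is the signal that the client has gone away.
		if _, err := fmt.Fprintf(w, "data: %s\n\n", chunk); err != nil {
			cancel() // stop the model instead of generating into the void
			break
		}
		flusher.Flush()
	}
}

func main() {
	http.HandleFunc("/stream", streamHandler)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```

With `net/http` the request context already fires on disconnect, so the write check there is belt-and-braces; under fasthttp it is the primary signal.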

View File

@@ -18,7 +18,7 @@ import (
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
-ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
resp := OpenAIResponse{
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{
@@ -38,7 +38,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
}
return func(c *fiber.Ctx) error {
-modelFile, input, err := readInput(c, o.Loader, true)
+modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -130,7 +130,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Template found, input modified to: %s", i)
}
-r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {

View File

@@ -13,7 +13,7 @@ import (
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-modelFile, input, err := readInput(c, o.Loader, true)
+modelFile, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
@@ -43,7 +43,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
log.Debug().Msgf("Template found, input modified to: %s", i)
}
-r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
}, nil)
if err != nil {

View File

@@ -14,7 +14,7 @@ import (
// https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-model, input, err := readInput(c, o.Loader, true)
+model, input, err := readInput(c, o, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View File

@@ -35,7 +35,7 @@ import (
*/
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-m, input, err := readInput(c, o.Loader, false)
+m, input, err := readInput(c, o, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View File

@@ -7,7 +7,8 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model"
)
-func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+    n := req.N
result := []Choice{}
if n == 0 {
@@ -15,7 +16,7 @@ func ComputeChoices(predInput string, n int, config *config.Config, o *options.O
}
// get the model function to call for the result
-predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback)
+predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
if err != nil {
return result, err
}

View File

@@ -1,6 +1,7 @@
package openai
import (
"context"
"encoding/json"
"fmt"
"os"
@@ -8,13 +9,18 @@ import (
"strings"
config "github.com/go-skynet/LocalAI/api/config"
options "github.com/go-skynet/LocalAI/api/options"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
-func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
+func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) {
+    loader := o.Loader
input := new(OpenAIRequest)
+    ctx, cancel := context.WithCancel(o.Context)
+    input.Context = ctx
+    input.Cancel = cancel
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", nil, err

View File

@@ -19,7 +19,7 @@ import (
// https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
-m, input, err := readInput(c, o.Loader, false)
+m, input, err := readInput(c, o, false)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}

View File

@@ -78,7 +78,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
return err
}
log.Debug().Msgf("Loading GRPC Process", grpcProcess)
log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)

View File

@@ -102,7 +102,6 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Cl
// Check if we already have a loaded model
if model := ml.checkIsLoaded(modelName); model != nil {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return model, nil
}
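
Note: end to end, the effect is that dropping the HTTP connection mid-stream now stops token generation instead of letting it run to completion. A hypothetical client-side exercise of the behavior (endpoint, port, model name, and payload are illustrative; LocalAI's OpenAI-compatible API shape is assumed):

```go
package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	// Abort the request after two seconds to simulate a vanishing client.
	go func() {
		time.Sleep(2 * time.Second)
		cancel()
	}()

	body := strings.NewReader(`{"model":"gpt-3.5-turbo","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost:8080/v1/chat/completions", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The copy ends with "context canceled" once cancel fires; with this
	// commit, the server notices the broken stream and stops the model too.
	_, err = io.Copy(os.Stdout, resp.Body)
	fmt.Println("\nstream ended:", err)
}
```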