From 12fe0932c41246914e455c4175269a431fb8cf60 Mon Sep 17 00:00:00 2001
From: Aman Gupta Karmani
Date: Mon, 24 Jul 2023 14:10:54 -0700
Subject: [PATCH] feat: cancel stream generation if client disappears (#792)

---
 api/backend/llm.go          |  7 ++++---
 api/openai/api.go           |  4 ++++
 api/openai/chat.go          | 15 ++++++++++-----
 api/openai/completion.go    |  6 +++---
 api/openai/edit.go          |  4 ++--
 api/openai/embeddings.go    |  2 +-
 api/openai/image.go         |  2 +-
 api/openai/inference.go     |  5 +++--
 api/openai/request.go       |  8 +++++++-
 api/openai/transcription.go |  2 +-
 pkg/model/initializers.go   |  2 +-
 pkg/model/loader.go         |  1 -
 12 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/api/backend/llm.go b/api/backend/llm.go
index 23a5ca4c..c9a5e798 100644
--- a/api/backend/llm.go
+++ b/api/backend/llm.go
@@ -1,6 +1,7 @@
 package backend
 
 import (
+	"context"
 	"os"
 	"regexp"
 	"strings"
@@ -14,7 +15,7 @@ import (
 	"github.com/go-skynet/LocalAI/pkg/utils"
 )
 
-func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
+func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
 	modelFile := c.Model
 
 	grpcOpts := gRPCModelOpts(c)
@@ -66,13 +67,13 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
 		opts.Prompt = s
 		if tokenCallback != nil {
 			ss := ""
-			err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
+			err := inferenceModel.PredictStream(ctx, opts, func(s string) {
 				tokenCallback(s)
 				ss += s
 			})
 			return ss, err
 		} else {
-			reply, err := inferenceModel.Predict(o.Context, opts)
+			reply, err := inferenceModel.Predict(ctx, opts)
 			if err != nil {
 				return "", err
 			}
diff --git a/api/openai/api.go b/api/openai/api.go
index bf2b6394..ac757f47 100644
--- a/api/openai/api.go
+++ b/api/openai/api.go
@@ -1,6 +1,7 @@
 package openai
 
 import (
+	"context"
 
 	config "github.com/go-skynet/LocalAI/api/config"
 	"github.com/go-skynet/LocalAI/pkg/grammar"
@@ -70,6 +71,9 @@ type OpenAIModel struct {
 type OpenAIRequest struct {
 	config.PredictionOptions
 
+	Context context.Context
+	Cancel  context.CancelFunc
+
 	// whisper
 	File string `json:"file" validate:"required"`
 	//whisper/image
diff --git a/api/openai/chat.go b/api/openai/chat.go
index a9cbd240..0c603886 100644
--- a/api/openai/chat.go
+++ b/api/openai/chat.go
@@ -28,7 +28,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		}
 		responses <- initialMessage
 
-		ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+		ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
 				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
@@ -43,7 +43,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 	return func(c *fiber.Ctx) error {
 		processFunctions := false
 		funcs := grammar.Functions{}
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -235,7 +235,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 				enc.Encode(ev)
 
 				log.Debug().Msgf("Sending chunk: %s", buf.String())
-				fmt.Fprintf(w, "data: %v\n", buf.String())
+				_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+				if err != nil {
+					log.Debug().Msgf("Sending chunk failed: %v", err)
+					input.Cancel()
+					break
+				}
 				w.Flush()
 			}
 
@@ -258,7 +263,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			return nil
 		}
 
-		result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
 			if processFunctions {
 				// As we have to change the result before processing, we can't stream the answer (yet?)
 				ss := map[string]interface{}{}
@@ -300,7 +305,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 				// Otherwise ask the LLM to understand the JSON output and the context, and return a message
 				// Note: This costs (in term of CPU) another computation
 				config.Grammar = ""
-				predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil)
+				predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
 				if err != nil {
 					log.Error().Msgf("inference error: %s", err.Error())
 					return
diff --git a/api/openai/completion.go b/api/openai/completion.go
index 1efe37c7..e7406ebb 100644
--- a/api/openai/completion.go
+++ b/api/openai/completion.go
@@ -18,7 +18,7 @@ import (
 // https://platform.openai.com/docs/api-reference/completions
 func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
-		ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+		ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
 				Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{
@@ -38,7 +38,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 	}
 
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -130,7 +130,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}
 
-		r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
 			*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
 		}, nil)
 		if err != nil {
diff --git a/api/openai/edit.go b/api/openai/edit.go
index 459c9748..d5a7f279 100644
--- a/api/openai/edit.go
+++ b/api/openai/edit.go
@@ -13,7 +13,7 @@ import (
 
 func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -43,7 +43,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}
 
-		r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
 			*c = append(*c, Choice{Text: s})
 		}, nil)
 		if err != nil {
diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go
index 248ae5cf..7d47060d 100644
--- a/api/openai/embeddings.go
+++ b/api/openai/embeddings.go
@@ -14,7 +14,7 @@ import (
 // https://platform.openai.com/docs/api-reference/embeddings
 func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readInput(c, o.Loader, true)
+		model, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/api/openai/image.go b/api/openai/image.go
index bca54c16..8d0e7b8a 100644
--- a/api/openai/image.go
+++ b/api/openai/image.go
@@ -35,7 +35,7 @@ import (
 */
 func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readInput(c, o.Loader, false)
+		m, input, err := readInput(c, o, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/api/openai/inference.go b/api/openai/inference.go
index a9991fa0..68d7ae85 100644
--- a/api/openai/inference.go
+++ b/api/openai/inference.go
@@ -7,7 +7,8 @@ import (
 	model "github.com/go-skynet/LocalAI/pkg/model"
 )
 
-func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+	n := req.N
 	result := []Choice{}
 
 	if n == 0 {
@@ -15,7 +16,7 @@ func ComputeChoices(predInput string, n int, config *config.Config, o *options.O
 	}
 
 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
 	if err != nil {
 		return result, err
 	}
diff --git a/api/openai/request.go b/api/openai/request.go
index 84dbaa8e..03c4b806 100644
--- a/api/openai/request.go
+++ b/api/openai/request.go
@@ -1,6 +1,7 @@
 package openai
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"os"
@@ -8,13 +9,18 @@ import (
 	"strings"
 
 	config "github.com/go-skynet/LocalAI/api/config"
+	options "github.com/go-skynet/LocalAI/api/options"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )
 
-func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
+func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) {
+	loader := o.Loader
 	input := new(OpenAIRequest)
+	ctx, cancel := context.WithCancel(o.Context)
+	input.Context = ctx
+	input.Cancel = cancel
 	// Get input data from the request body
 	if err := c.BodyParser(input); err != nil {
 		return "", nil, err
diff --git a/api/openai/transcription.go b/api/openai/transcription.go
index 4b4a65e0..895c110f 100644
--- a/api/openai/transcription.go
+++ b/api/openai/transcription.go
@@ -19,7 +19,7 @@ import (
 // https://platform.openai.com/docs/api-reference/audio/create
 func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readInput(c, o.Loader, false)
+		m, input, err := readInput(c, o, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 86ee554d..22bee483 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -78,7 +78,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
 		return err
 	}
 
-	log.Debug().Msgf("Loading GRPC Process", grpcProcess)
+	log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
 
 	log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
 
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index bb49a7cc..daadc969 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -102,7 +102,6 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Cl
 
 	// Check if we already have a loaded model
 	if model := ml.checkIsLoaded(modelName); model != nil {
-		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
 		return model, nil
 	}
 
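
---

Note on the pattern: the patch derives a cancellable context per request (context.WithCancel in readInput), threads it down through ComputeChoices and ModelInference to the gRPC Predict/PredictStream calls, and calls Cancel when a streaming write to the client fails. The sketch below is a minimal, self-contained analogue of that pattern using only the Go standard library; it is not LocalAI code, and generateTokens, handler, and the :8080 address are hypothetical stand-ins for the backend's PredictStream plumbing.

	package main

	import (
		"context"
		"fmt"
		"log"
		"net/http"
		"time"
	)

	// generateTokens is a hypothetical stand-in for PredictStream: it emits
	// tokens until its context is cancelled, mirroring how the backend
	// observes the ctx now passed into ModelInference.
	func generateTokens(ctx context.Context, out chan<- string) {
		defer close(out)
		for i := 0; ; i++ {
			select {
			case <-ctx.Done(): // client disappeared: stop generating
				return
			case out <- fmt.Sprintf("token-%d", i):
				time.Sleep(100 * time.Millisecond) // simulate generation latency
			}
		}
	}

	func handler(w http.ResponseWriter, r *http.Request) {
		// A cancellable per-request context, as readInput now creates with
		// context.WithCancel(o.Context) and stores on the OpenAIRequest.
		ctx, cancel := context.WithCancel(r.Context())
		defer cancel()

		tokens := make(chan string)
		go generateTokens(ctx, tokens)

		for t := range tokens {
			// As in the patched chat.go loop: if the write fails, cancel
			// generation instead of discarding the error.
			if _, err := fmt.Fprintf(w, "data: %v\n\n", t); err != nil {
				log.Printf("sending chunk failed: %v", err)
				cancel()
				continue // keep draining so the producer goroutine can exit
			}
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}

	func main() {
		http.HandleFunc("/stream", handler)
		log.Fatal(http.ListenAndServe(":8080", nil))
	}

The design point is the same as the patch's: before this change the SSE loop ignored the Fprintf error, so a disconnected client left the model generating tokens nobody would read; cancelling on the first failed write propagates through the context to the backend and frees that compute.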