feat: cancel stream generation if client disappears (#792)

Author: Aman Gupta Karmani
Date: 2023-07-24 14:10:54 -07:00 (committed by GitHub)
parent 72e3e236de
commit 12fe0932c4
12 changed files with 37 additions and 21 deletions
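At a glance, the change threads a per-request context.Context from the HTTP layer down to the backend Predict/PredictStream calls, and cancels that context as soon as a streaming write to the client fails. Below is a minimal standalone sketch of the pattern, not LocalAI's actual code: names are illustrative and plain net/http stands in for the fiber/fasthttp stack the project uses.

package main

import (
	"context"
	"fmt"
	"net/http"
)

// generate is a stand-in for the model backend: it emits tokens until its
// context is cancelled.
func generate(ctx context.Context) <-chan string {
	out := make(chan string)
	go func() {
		defer close(out)
		for _, t := range []string{"hello", " ", "world"} {
			select {
			case out <- t:
			case <-ctx.Done():
				return // generation aborted
			}
		}
	}()
	return out
}

func streamHandler(serverCtx context.Context) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Derive a request-scoped context, as readInput does in this commit.
		ctx, cancel := context.WithCancel(serverCtx)
		defer cancel()

		w.Header().Set("Content-Type", "text/event-stream")
		for token := range generate(ctx) {
			if _, err := fmt.Fprintf(w, "data: %s\n\n", token); err != nil {
				cancel() // client disappeared: stop the generator
				break
			}
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}
}

func main() {
	http.Handle("/stream", streamHandler(context.Background()))
	http.ListenAndServe(":8080", nil)
}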

View File

@@ -1,6 +1,7 @@
 package backend

 import (
+	"context"
 	"os"
 	"regexp"
 	"strings"
@@ -14,7 +15,7 @@ import (
 	"github.com/go-skynet/LocalAI/pkg/utils"
 )

-func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
+func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
 	modelFile := c.Model

 	grpcOpts := gRPCModelOpts(c)
@@ -66,13 +67,13 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
 	opts.Prompt = s
 	if tokenCallback != nil {
 		ss := ""
-		err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
+		err := inferenceModel.PredictStream(ctx, opts, func(s string) {
			tokenCallback(s)
			ss += s
		})
 		return ss, err
 	} else {
-		reply, err := inferenceModel.Predict(o.Context, opts)
+		reply, err := inferenceModel.Predict(ctx, opts)
 		if err != nil {
 			return "", err
 		}
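This signature change is the core plumbing: ModelInference now receives the caller's request-scoped ctx instead of reusing the long-lived o.Context, so cancelling one request no longer waits for application shutdown. A hypothetical stand-in (not LocalAI's gRPC API) sketches the behavior the backend is expected to have:

package backendsketch

import "context"

// predictStream is a hypothetical stand-in for a streaming predictor: it
// checks ctx between tokens, so a Cancel() upstream stops the loop early.
func predictStream(ctx context.Context, tokens []string, cb func(string)) error {
	for _, t := range tokens {
		select {
		case <-ctx.Done():
			return ctx.Err() // context.Canceled once the client is gone
		default:
			cb(t)
		}
	}
	return nil
}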

View File

@@ -1,6 +1,7 @@
 package openai

 import (
+	"context"

 	config "github.com/go-skynet/LocalAI/api/config"
 	"github.com/go-skynet/LocalAI/pkg/grammar"
@@ -70,6 +71,9 @@ type OpenAIModel struct {
 type OpenAIRequest struct {
 	config.PredictionOptions

+	Context context.Context
+	Cancel  context.CancelFunc
+
 	// whisper
 	File string `json:"file" validate:"required"`
 	//whisper/image
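The two new fields carry the cancellation pair with every parsed request, so any code that already receives the *OpenAIRequest gains access to both without further signature churn. A minimal sketch of the shape, with hypothetical names:

package openaisketch

import "context"

// request mirrors what OpenAIRequest becomes in this commit: the context is
// consulted by the inference path, the CancelFunc by the HTTP writer that
// detects the disconnect.
type request struct {
	Context context.Context
	Cancel  context.CancelFunc
}

// newRequest derives a per-request context from the application context.
func newRequest(appCtx context.Context) *request {
	ctx, cancel := context.WithCancel(appCtx)
	return &request{Context: ctx, Cancel: cancel}
}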

View File

@@ -28,7 +28,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		}
 		responses <- initialMessage

-		ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+		ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
 				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
@@ -43,7 +43,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 	return func(c *fiber.Ctx) error {
 		processFunctions := false
 		funcs := grammar.Functions{}
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -235,7 +235,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			enc.Encode(ev)

 			log.Debug().Msgf("Sending chunk: %s", buf.String())
-			fmt.Fprintf(w, "data: %v\n", buf.String())
+			_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+			if err != nil {
+				log.Debug().Msgf("Sending chunk failed: %v", err)
+				input.Cancel()
+				break
+			}
 			w.Flush()
 		}
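Detecting the disconnect hinges on the error now captured from fmt.Fprintf. The w.Flush() call suggests the stream writer here is buffered, and a useful property of Go's bufio.Writer is that errors are sticky: once an underlying write fails (often surfacing on a Flush), every later write returns that error immediately, which is what lets this check fire. A small self-contained demonstration of that property:

package main

import (
	"bufio"
	"errors"
	"fmt"
)

// brokenConn simulates a client that has gone away.
type brokenConn struct{}

func (brokenConn) Write(p []byte) (int, error) {
	return 0, errors.New("connection reset by peer")
}

func main() {
	w := bufio.NewWriter(brokenConn{})
	fmt.Fprintf(w, "data: chunk 1\n") // buffered, no error yet
	w.Flush()                         // underlying write fails; error is recorded
	if _, err := fmt.Fprintf(w, "data: chunk 2\n"); err != nil {
		fmt.Println("client disappeared:", err) // where the commit calls input.Cancel()
	}
}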
@@ -258,7 +263,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			return nil
 		}

-		result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
 			if processFunctions {
 				// As we have to change the result before processing, we can't stream the answer (yet?)
 				ss := map[string]interface{}{}
@@ -300,7 +305,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			// Otherwise ask the LLM to understand the JSON output and the context, and return a message
 			// Note: This costs (in term of CPU) another computation
 			config.Grammar = ""
-			predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil)
+			predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
 			if err != nil {
 				log.Error().Msgf("inference error: %s", err.Error())
 				return

View File

@@ -18,7 +18,7 @@ import (
 // https://platform.openai.com/docs/api-reference/completions
 func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
-		ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+		ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
 				Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{
@@ -38,7 +38,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -130,7 +130,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}

-		r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
 			*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
 		}, nil)
 		if err != nil {

View File

@@ -13,7 +13,7 @@ import (
 func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -43,7 +43,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}

-		r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
 			*c = append(*c, Choice{Text: s})
 		}, nil)
 		if err != nil {

View File

@@ -14,7 +14,7 @@ import (
 // https://platform.openai.com/docs/api-reference/embeddings
 func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readInput(c, o.Loader, true)
+		model, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

View File

@@ -35,7 +35,7 @@ import (
 */
 func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readInput(c, o.Loader, false)
+		m, input, err := readInput(c, o, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

View File

@@ -7,7 +7,8 @@ import (
 	model "github.com/go-skynet/LocalAI/pkg/model"
 )

-func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+	n := req.N
 	result := []Choice{}

 	if n == 0 {
@@ -15,7 +16,7 @@ func ComputeChoices(predInput string, n int, config *config.Config, o *options.O
 	}

 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
 	if err != nil {
 		return result, err
 	}
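With req threaded in, ComputeChoices reaches both req.N and req.Context, and the commit passes the context on to ModelInference. The hypothetical, simplified variant below is not what the commit does; it illustrates the same idea applied between choices, so a disconnect during choice i also skips the remaining n-i generations:

package openaisketch

import "context"

// computeChoices is a simplified sketch: predict stands in for the
// ModelInference closure, and the ctx check between iterations stops
// further generations once the request has been cancelled.
func computeChoices(ctx context.Context, n int, predict func() (string, error)) ([]string, error) {
	out := make([]string, 0, n)
	for i := 0; i < n; i++ {
		if err := ctx.Err(); err != nil {
			return out, err // request cancelled: skip remaining choices
		}
		s, err := predict()
		if err != nil {
			return out, err
		}
		out = append(out, s)
	}
	return out, nil
}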

View File

@@ -1,6 +1,7 @@
 package openai

 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"os"
@@ -8,13 +9,18 @@ import (
 	"strings"

 	config "github.com/go-skynet/LocalAI/api/config"
+	options "github.com/go-skynet/LocalAI/api/options"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )

-func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
+func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) {
+	loader := o.Loader
 	input := new(OpenAIRequest)
+	ctx, cancel := context.WithCancel(o.Context)
+	input.Context = ctx
+	input.Cancel = cancel
 	// Get input data from the request body
 	if err := c.BodyParser(input); err != nil {
 		return "", nil, err

View File

@@ -19,7 +19,7 @@ import (
 // https://platform.openai.com/docs/api-reference/audio/create
 func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readInput(c, o.Loader, false)
+		m, input, err := readInput(c, o, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}

View File

@@ -78,7 +78,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
 		return err
 	}

-	log.Debug().Msgf("Loading GRPC Process", grpcProcess)
+	log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)

 	log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)

View File

@@ -102,7 +102,6 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Cl
 	// Check if we already have a loaded model
 	if model := ml.checkIsLoaded(modelName); model != nil {
-		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
 		return model, nil
 	}
} }