Mirror of https://github.com/mudler/LocalAI.git (synced 2024-06-07 19:40:48 +00:00)
feat: cancel stream generation if client disappears (#792)
Commit 12fe0932c4 (parent 72e3e236de)
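In short: readInput now derives a cancellable context.Context per request, stores it on the OpenAIRequest together with its CancelFunc, and that context is threaded through ComputeChoices and backend.ModelInference down to the gRPC Predict/PredictStream calls. When writing a streamed chunk to the client fails, the chat handler calls input.Cancel(), aborting the in-flight generation instead of burning CPU for a listener that is gone.

Below is a minimal, self-contained sketch of the pattern, not LocalAI code; generateTokens is a hypothetical stand-in for the streaming backend:

package main

import (
	"context"
	"fmt"
	"time"
)

// generateTokens is a hypothetical stand-in for a streaming backend call.
// It emits tokens until the context is cancelled.
func generateTokens(ctx context.Context, out chan<- string) {
	defer close(out)
	for i := 0; ; i++ {
		select {
		case <-ctx.Done(): // client went away: stop generating
			return
		case out <- fmt.Sprintf("token-%d", i):
			time.Sleep(10 * time.Millisecond)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	tokens := make(chan string)
	go generateTokens(ctx, tokens)

	for tok := range tokens {
		// A failed write to the client would surface as an error here;
		// we simulate the disconnect after a few tokens.
		if tok == "token-3" {
			cancel() // stop the generator, as the commit does via input.Cancel()
		}
		fmt.Println(tok)
	}
}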
In package backend (ModelInference):

@@ -1,6 +1,7 @@
 package backend

 import (
+	"context"
 	"os"
 	"regexp"
 	"strings"
@@ -14,7 +15,7 @@ import (
 	"github.com/go-skynet/LocalAI/pkg/utils"
 )

-func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
+func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
 	modelFile := c.Model

 	grpcOpts := gRPCModelOpts(c)
@@ -66,13 +67,13 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
 	opts.Prompt = s
 	if tokenCallback != nil {
 		ss := ""
-		err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
+		err := inferenceModel.PredictStream(ctx, opts, func(s string) {
 			tokenCallback(s)
 			ss += s
 		})
 		return ss, err
 	} else {
-		reply, err := inferenceModel.Predict(o.Context, opts)
+		reply, err := inferenceModel.Predict(ctx, opts)
 		if err != nil {
 			return "", err
 		}
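The signature change above swaps the server-wide o.Context for a caller-supplied ctx, so each request controls the lifetime of its own prediction. A small sketch of why that matters for the one-shot Predict path as well; predict below is a hypothetical stand-in for a context-aware gRPC call:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// predict simulates a blocking inference call that honours cancellation,
// as a context-aware gRPC unary call would.
func predict(ctx context.Context, prompt string) (string, error) {
	select {
	case <-time.After(2 * time.Second): // simulated generation time
		return "a reply to: " + prompt, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel() // e.g. the HTTP client disconnected
	}()

	if _, err := predict(ctx, "hello"); errors.Is(err, context.Canceled) {
		fmt.Println("generation aborted:", err)
	}
}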
In package openai (the OpenAIRequest type):

@@ -1,6 +1,7 @@
 package openai

 import (
+	"context"
 	config "github.com/go-skynet/LocalAI/api/config"

 	"github.com/go-skynet/LocalAI/pkg/grammar"
@@ -70,6 +71,9 @@ type OpenAIModel struct {
 type OpenAIRequest struct {
 	config.PredictionOptions

+	Context context.Context
+	Cancel  context.CancelFunc
+
 	// whisper
 	File string `json:"file" validate:"required"`
 	//whisper/image
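Storing a Context on a struct goes against the usual Go guideline of passing ctx as a first argument, but it is a pragmatic way to carry per-request state through the existing handler plumbing without changing every signature at once. A sketch of the shape, assuming only what the diff shows (the two new fields); the surrounding type and constructor are illustrative:

package main

import (
	"context"
	"fmt"
)

// request mirrors the two fields added to OpenAIRequest.
type request struct {
	Context context.Context
	Cancel  context.CancelFunc
}

func newRequest(parent context.Context) *request {
	ctx, cancel := context.WithCancel(parent)
	return &request{Context: ctx, Cancel: cancel}
}

func main() {
	req := newRequest(context.Background())
	req.Cancel()                   // a failed client write triggers this in the handler
	fmt.Println(req.Context.Err()) // context.Canceled
}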
In ChatEndpoint:

@@ -28,7 +28,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 		}
 		responses <- initialMessage

-		ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+		ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
 				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
@@ -43,7 +43,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 	return func(c *fiber.Ctx) error {
 		processFunctions := false
 		funcs := grammar.Functions{}
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -235,7 +235,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 				enc.Encode(ev)

 				log.Debug().Msgf("Sending chunk: %s", buf.String())
-				fmt.Fprintf(w, "data: %v\n", buf.String())
+				_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+				if err != nil {
+					log.Debug().Msgf("Sending chunk failed: %v", err)
+					input.Cancel()
+					break
+				}
 				w.Flush()
 			}
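One caveat worth noting with this error check: w here is buffered (fiber's stream writer hands the handler a bufio.Writer), so a small Fprintf may succeed by merely filling the buffer, and the broken connection only surfaces on Flush. A self-contained demonstration of that behaviour; a stricter variant of the handler could therefore also check the Flush error:

package main

import (
	"bufio"
	"errors"
	"fmt"
)

// failingWriter simulates a client that has gone away: every write errors.
type failingWriter struct{}

func (failingWriter) Write(p []byte) (int, error) { return 0, errors.New("broken pipe") }

func main() {
	w := bufio.NewWriterSize(failingWriter{}, 4096)

	// The buffered write itself succeeds: it only fills the buffer...
	if _, err := fmt.Fprintf(w, "data: %v\n", "chunk"); err != nil {
		fmt.Println("write failed:", err)
	}
	// ...so the disconnect only surfaces on Flush, whose error matters too.
	if err := w.Flush(); err != nil {
		fmt.Println("flush failed:", err) // prints: flush failed: broken pipe
	}
}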
Still in ChatEndpoint:

@@ -258,7 +263,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			return nil
 		}

-		result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		result, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]Choice) {
 			if processFunctions {
 				// As we have to change the result before processing, we can't stream the answer (yet?)
 				ss := map[string]interface{}{}
@@ -300,7 +305,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			// Otherwise ask the LLM to understand the JSON output and the context, and return a message
 			// Note: This costs (in term of CPU) another computation
 			config.Grammar = ""
-			predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil)
+			predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
 			if err != nil {
 				log.Error().Msgf("inference error: %s", err.Error())
 				return
In CompletionEndpoint:

@@ -18,7 +18,7 @@ import (
 // https://platform.openai.com/docs/api-reference/completions
 func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
-		ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+		ComputeChoices(req, s, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 			resp := OpenAIResponse{
 				Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{
@@ -38,7 +38,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 	}

 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -130,7 +130,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}

-		r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
 			*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k})
 		}, nil)
 		if err != nil {
In EditEndpoint:

@@ -13,7 +13,7 @@ import (

 func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		modelFile, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -43,7 +43,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}

-		r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) {
+		r, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]Choice) {
 			*c = append(*c, Choice{Text: s})
 		}, nil)
 		if err != nil {
In EmbeddingsEndpoint:

@@ -14,7 +14,7 @@ import (
 // https://platform.openai.com/docs/api-reference/embeddings
 func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		model, input, err := readInput(c, o.Loader, true)
+		model, input, err := readInput(c, o, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
In ImageEndpoint:

@@ -35,7 +35,7 @@ import (
 */
 func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readInput(c, o.Loader, false)
+		m, input, err := readInput(c, o, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
In ComputeChoices:

@@ -7,7 +7,8 @@ import (
 	model "github.com/go-skynet/LocalAI/pkg/model"
 )

-func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+func ComputeChoices(req *OpenAIRequest, predInput string, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+	n := req.N
 	result := []Choice{}

 	if n == 0 {
@@ -15,7 +16,7 @@ func ComputeChoices(predInput string, n int, config *config.Config, o *options.O
 	}

 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
 	if err != nil {
 		return result, err
 	}
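ComputeChoices now receives the whole request rather than just the prompt and n, which gives it access to both req.N and req.Context. A hedged sketch of the resulting calling convention; the hunk does not show what n defaults to when the client omits it, so the value 1 below is an assumption:

package main

import (
	"context"
	"fmt"
)

// openAIRequest carries both the number of choices and the cancellation
// context, mirroring the new parameter shape.
type openAIRequest struct {
	N       int
	Context context.Context
}

func computeChoices(req *openAIRequest, prompt string) ([]string, error) {
	n := req.N
	if n == 0 {
		n = 1 // assumed default; the real branch body is not shown in the hunk
	}
	choices := make([]string, 0, n)
	for i := 0; i < n; i++ {
		if err := req.Context.Err(); err != nil {
			return choices, err // stop early once the request is cancelled
		}
		choices = append(choices, fmt.Sprintf("choice %d for %q", i, prompt))
	}
	return choices, nil
}

func main() {
	req := &openAIRequest{N: 2, Context: context.Background()}
	choices, _ := computeChoices(req, "hello")
	fmt.Println(choices)
}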
In readInput (package openai):

@@ -1,6 +1,7 @@
 package openai

 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"os"
@@ -8,13 +9,18 @@ import (
 	"strings"

 	config "github.com/go-skynet/LocalAI/api/config"
+	options "github.com/go-skynet/LocalAI/api/options"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )

-func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
+func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *OpenAIRequest, error) {
+	loader := o.Loader
 	input := new(OpenAIRequest)
+	ctx, cancel := context.WithCancel(o.Context)
+	input.Context = ctx
+	input.Cancel = cancel
 	// Get input data from the request body
 	if err := c.BodyParser(input); err != nil {
 		return "", nil, err
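Because the per-request context is derived from o.Context (the server's root context) via context.WithCancel, cancellation composes: cancelling one request leaves the others untouched, while tearing down the root context cancels every in-flight request. A runnable sketch:

package main

import (
	"context"
	"fmt"
)

func main() {
	// Server-wide root context, playing the role of o.Context in the diff.
	serverCtx, shutdown := context.WithCancel(context.Background())

	// Each request derives its own cancellable child, as readInput now does.
	req1, cancel1 := context.WithCancel(serverCtx)
	req2, cancel2 := context.WithCancel(serverCtx)
	defer cancel2()

	cancel1() // one client disconnects: only its request is cancelled
	fmt.Println(req1.Err(), req2.Err()) // context.Canceled <nil>

	shutdown() // server stops: every remaining request is cancelled too
	fmt.Println(req2.Err()) // context.Canceled
}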
In TranscriptEndpoint:

@@ -19,7 +19,7 @@ import (
 // https://platform.openai.com/docs/api-reference/audio/create
 func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
-		m, input, err := readInput(c, o.Loader, false)
+		m, input, err := readInput(c, o, false)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
In the model loader (startProcess):

@@ -78,7 +78,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string
 		return err
 	}

-	log.Debug().Msgf("Loading GRPC Process", grpcProcess)
+	log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)

 	log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)

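The log fix above is a classic printf-family bug: zerolog's Msgf formats like fmt.Sprintf, so an argument without a matching verb is never rendered and leaves a %!(EXTRA ...) artifact in the output (go vet also flags this). A quick demonstration using fmt.Sprintf directly:

package main

import "fmt"

func main() {
	// The buggy form: an argument but no formatting directive.
	fmt.Println(fmt.Sprintf("Loading GRPC Process", "llama-grpc"))
	// prints: Loading GRPC Process%!(EXTRA string=llama-grpc)

	// The fixed form renders the argument properly.
	fmt.Println(fmt.Sprintf("Loading GRPC Process: %s", "llama-grpc"))
	// prints: Loading GRPC Process: llama-grpc
}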
In ModelLoader.LoadModel:

@@ -102,7 +102,6 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Cl

 	// Check if we already have a loaded model
 	if model := ml.checkIsLoaded(modelName); model != nil {
-		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
 		return model, nil
 	}
