fix: missing returning error and free callback stream (#187)

parent 77ce8b953e
commit 714bfcd45b
@@ -299,6 +299,21 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 }
 
 func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+
+	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
+		ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+			resp := OpenAIResponse{
+				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}},
+				Object:  "chat.completion.chunk",
+			}
+			log.Debug().Msgf("Sending goroutine: %s", s)
+
+			responses <- resp
+			return true
+		})
+		close(responses)
+	}
 	return func(c *fiber.Ctx) error {
 		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
 		if err != nil {
@@ -350,19 +365,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
 	if input.Stream {
 		responses := make(chan OpenAIResponse)
 
-		go func() {
-			ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
-				resp := OpenAIResponse{
-					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-					Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}},
-					Object:  "chat.completion.chunk",
-				}
-
-				responses <- resp
-				return true
-			})
-			close(responses)
-		}()
+		go process(predInput, input, config, loader, responses)
 
 		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@@ -261,10 +261,15 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 		predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
 	}
 
-	return model.Predict(
+	str, er := model.Predict(
 		s,
 		predictOptions...,
 	)
+	// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
+	// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
+	// after a stream event has occurred
+	model.SetTokenCallback(nil)
+	return str, er
 	}
 }
@@ -81,10 +81,9 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
 		if exists {
 			m = t
 		}
-
 	}
 	if m == nil {
-		return "", nil
+		return "", fmt.Errorf("failed loading any template")
 	}
 
 	var buf bytes.Buffer
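This last hunk covers the "missing returning error" half of the commit title: TemplatePrefix used to report success with an empty string even when no template could be loaded, so callers could not tell a missing template apart from an empty one. A hedged sketch of the before/after behavior from a caller's point of view, using a simplified stand-in rather than the real ModelLoader:

package main

import (
	"errors"
	"fmt"
)

// templatePrefix stands in for ModelLoader.TemplatePrefix after this
// commit: when no template can be loaded it reports an error instead
// of returning ("", nil).
func templatePrefix(templates map[string]string, name string) (string, error) {
	t, ok := templates[name]
	if !ok {
		return "", errors.New("failed loading any template")
	}
	return t, nil
}

func main() {
	templates := map[string]string{"chat": "### Instruction: {{.Input}}"}

	if _, err := templatePrefix(templates, "missing"); err != nil {
		// Before the fix this branch was unreachable: a missing
		// template and a successfully rendered empty prefix both
		// looked like ("", nil) to the caller.
		fmt.Println("template error:", err)
	}
}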