diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 341dc34b..b2e7aa75 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -25,7 +25,7 @@ import (
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
- emptyMessage := ""
+ textContentToReturn := ""
id := uuid.New().String()
created := int(time.Now().Unix())
@@ -34,7 +34,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
- Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
+ Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
@@ -69,6 +69,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
results := functions.ParseFunctionCall(result, config.FunctionsConfig)
+ textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
switch {
@@ -77,7 +78,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
- Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
+ Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
@@ -449,7 +450,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
{
FinishReason: finishReason,
Index: 0,
- Delta: &schema.Message{Content: &emptyMessage},
+ Delta: &schema.Message{Content: &textContentToReturn},
}},
Object: "chat.completion.chunk",
Usage: *usage,
@@ -473,6 +474,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
s = functions.CleanupLLMResult(s, config.FunctionsConfig)
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+ textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
switch {
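Taken together, the chat.go changes mean that when tools are in play the assistant delta no longer hard-codes an empty string: it carries whatever `ParseTextContent` captured from the raw reply, while the tool calls themselves still come from `ParseFunctionCall`. A minimal sketch of that flow, using an illustrative `<sketchpad>` tag and regexes that are assumptions here rather than shipped defaults:

```go
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	// Hypothetical raw model output: free-text reasoning wrapped in a
	// tag, followed by a tool call. The tag name is illustrative only.
	raw := `<sketchpad>The user asked for a sum, so I will call add.</sketchpad>
{"function": "add", "arguments": {"x": 5, "y": 3}}`

	cfg := functions.FunctionsConfig{
		// Pull out the JSON object(s) that hold the tool call(s).
		JSONRegexMatch: []string{`(?s)(\{.*\})`},
		// Pull out the free text that becomes the message content.
		CaptureLLMResult: []string{`(?s)<sketchpad>(.*?)</sketchpad>`},
	}

	calls := functions.ParseFunctionCall(raw, cfg)
	text := functions.ParseTextContent(raw, cfg)

	fmt.Println(text) // The user asked for a sum, so I will call add.
	for _, c := range calls {
		fmt.Println(c.Name, c.Arguments) // add {"x":5,"y":3}
	}
}
```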
diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go
index d6e9d320..7bb3e6bd 100644
--- a/pkg/functions/parse.go
+++ b/pkg/functions/parse.go
@@ -3,6 +3,7 @@ package functions
import (
"encoding/json"
"regexp"
+ "strings"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
@@ -59,6 +60,11 @@ type FunctionsConfig struct {
// ReplaceLLMResult allow to replace strings in the results before parsing them
ReplaceLLMResult []ReplaceResult `yaml:"replace_llm_results"`
+ // CaptureLLMResult is a list of regexes to extract a string from the LLM response
+ // that is used as the returned message content when using tools.
+ // This is useful, for example, when the LLM outputs reasoning alongside a tool call
+ // and we want to return that reasoning as the text content.
+ CaptureLLMResult []string `yaml:"capture_llm_results"`
+
// FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
// instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
// This might be useful for certain models trained with the function name as the first token.
@@ -109,6 +115,20 @@ func CleanupLLMResult(llmresult string, functionConfig FunctionsConfig) string {
return llmresult
}
+func ParseTextContent(llmresult string, functionConfig FunctionsConfig) string {
+ for _, r := range functionConfig.CaptureLLMResult {
+ // We use a regex to extract the text content from the response
+ var respRegex = regexp.MustCompile(r)
+ match := respRegex.FindStringSubmatch(llmresult)
+ if len(match) >= 2 {
+ m := strings.TrimSpace(match[1])
+ return m
+ }
+ }
+
+ return ""
+}
+
func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults {
log.Debug().Msgf("LLM result: %s", llmresult)
@@ -127,47 +147,52 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
}
results := []FuncCallResults{}
+ llmResults := []string{}
- returnResult := func(s string) (result []FuncCallResults, e error) {
+ returnResult := func(results []string) (result []FuncCallResults, e error) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
- var ss []map[string]interface{}
result = make([]FuncCallResults, 0)
- s = utils.EscapeNewLines(s)
- err := json.Unmarshal([]byte(s), &ss)
- if err != nil {
- // If the LLM result is a single object, try unmarshaling it into a single map
- var singleObj map[string]interface{}
- err = json.Unmarshal([]byte(s), &singleObj)
+
+ for _, s := range results {
+ var ss []map[string]interface{}
+
+ s = utils.EscapeNewLines(s)
+ err := json.Unmarshal([]byte(s), &ss)
if err != nil {
- log.Debug().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result in a single object or an array of JSON objects")
- } else {
- ss = []map[string]interface{}{singleObj}
- }
- }
-
- log.Debug().Msgf("Function return: %s %+v", s, ss)
-
- for _, s := range ss {
- // The grammar defines the function name as "function", while OpenAI returns "name"
- func_name, ok := s[functionNameKey]
- if !ok {
- continue
- //return result, fmt.Errorf("unable to find function name in result")
- }
- // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
- args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
- if !ok {
- continue
- //return result, fmt.Errorf("unable to find arguments in result")
- }
- d, _ := json.Marshal(args)
- funcName, ok := func_name.(string)
- if !ok {
- continue
- //return result, fmt.Errorf("unable to cast function name to string")
+ // If the LLM result is a single object, try unmarshaling it into a single map
+ var singleObj map[string]interface{}
+ err = json.Unmarshal([]byte(s), &singleObj)
+ if err != nil {
+ log.Debug().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result in a single object or an array of JSON objects")
+ } else {
+ ss = []map[string]interface{}{singleObj}
+ }
}
- result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
+ log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+ for _, s := range ss {
+ // The grammar defines the function name as "function", while OpenAI returns "name"
+ func_name, ok := s[functionNameKey]
+ if !ok {
+ continue
+ //return result, fmt.Errorf("unable to find function name in result")
+ }
+ // Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
+ args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+ if !ok {
+ continue
+ //return result, fmt.Errorf("unable to find arguments in result")
+ }
+ d, _ := json.Marshal(args)
+ funcName, ok := func_name.(string)
+ if !ok {
+ continue
+ //return result, fmt.Errorf("unable to cast function name to string")
+ }
+
+ result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
+ }
}
return result, nil
@@ -179,10 +204,16 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
for _, r := range functionConfig.JSONRegexMatch {
// We use a regex to extract the JSON object from the response
var respRegex = regexp.MustCompile(r)
- match := respRegex.FindStringSubmatch(llmresult)
- if len(match) >= 2 {
- llmresult = match[1]
- log.Debug().Msgf("LLM result(JSONRegexMatch): %s", llmresult)
+ match := respRegex.FindAllStringSubmatch(llmresult, -1)
+ var allMatches []string
+ for _, m := range match {
+ if len(m) > 1 {
+ // we match the first group
+ allMatches = append(allMatches, m[1])
+ }
+ }
+ if len(allMatches) > 0 {
+ llmResults = append(llmResults, allMatches...)
break
}
}
@@ -193,22 +224,25 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
// obviously, this expects the LLM to be stable and return correctly formatted JSON
// TODO: optimize this and pre-compile it
var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
- match := respRegex.FindStringSubmatch(llmresult)
- for i, name := range respRegex.SubexpNames() {
- if i != 0 && name != "" && len(match) > i {
- result[name] = match[i]
+ matches := respRegex.FindAllStringSubmatch(llmresult, -1)
+ for _, match := range matches {
+ for i, name := range respRegex.SubexpNames() {
+ if i != 0 && name != "" && len(match) > i {
+ result[name] = match[i]
+ }
}
- }
- // TODO: open point about multiple results and/or mixed with chat messages
- // This is not handled as for now, we only expect one function call per response
- functionName := result[functionNameKey]
- if functionName == "" {
- return results
+ functionName := result[functionNameKey]
+ if functionName == "" {
+ return results
+ }
+ results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
}
- results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
} else {
- results, _ = returnResult(llmresult)
+ if len(llmResults) == 0 {
+ llmResults = append(llmResults, llmresult)
+ }
+ results, _ = returnResult(llmResults)
}
return results
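The `ResponseRegex` branch gets the same multi-match treatment: `FindAllStringSubmatch` replaces `FindStringSubmatch`, so every occurrence of the pattern becomes a call rather than only the first. A sketch with a made-up named-group pattern (the group names must be `function` and `arguments`, since those are the keys the parser reads back out):

```go
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	// Hypothetical pattern, not a shipped default.
	cfg := functions.FunctionsConfig{
		ResponseRegex: `(?P<function>\w+)\((?P<arguments>\{.*?\})\)`,
	}

	calls := functions.ParseFunctionCall(`add({"x": 5}) subtract({"x": 10})`, cfg)
	for _, c := range calls {
		fmt.Println(c.Name, c.Arguments)
	}
	// Before this change only "add" would have been returned; now both are.
}
```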
diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go
index 5e266c50..01d8469f 100644
--- a/pkg/functions/parse_test.go
+++ b/pkg/functions/parse_test.go
@@ -215,5 +215,48 @@ Some text after the JSON
Expect(results[0].Name).To(Equal("\"add\""))
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":"v\"value\"","z":"\"v\""}`))
})
+
+ It("should detect multiple functions call where the JSONRegexMatch is repeated", func() {
+ input := `
+Some text before the JSON
+{"function": "add", "arguments": {"x": 5, "y": 3}}
+{"function": "subtract", "arguments": {"x": 10, "y": 7}}
+Some text after the JSON
+`
+ functionConfig.JSONRegexMatch = []string{`(?s)<tool_call>(.*?)</tool_call>`}
+
+ results := ParseFunctionCall(input, functionConfig)
+ Expect(results).To(HaveLen(2))
+ Expect(results[0].Name).To(Equal("add"))
+ Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+ Expect(results[1].Name).To(Equal("subtract"))
+ Expect(results[1].Arguments).To(Equal(`{"x":10,"y":7}`))
+ })
+ })
+ Context("ParseTextContent", func() {
+ It("Can extract notes from the LLM result", func() {
+ input := `
+ Some text before the JSON
+<sketchpad>
+roses are red
+</sketchpad>
+ {"function": "subtract", "arguments": {"x": 10, "y": 7}}
+ Some text after the JSON
+ `
+ functionConfig.CaptureLLMResult = []string{`(?s)<sketchpad>(.*?)</sketchpad>`}
+ results := ParseTextContent(input, functionConfig)
+ Expect(results).To(Equal("roses are red"))
+ })
+
+ It("Defaults to empty if doesn't catch any", func() {
+ input := `
+ Some text before the JSON
+ {"function": "subtract", "arguments": {"x": 10, "y": 7}}
+ Some text after the JSON
+ `
+ functionConfig.CaptureLLMResult = []string{`(?s)<sketchpad>(.*?)</sketchpad>`}
+ results := ParseTextContent(input, functionConfig)
+ Expect(results).To(Equal(""))
+ })
})
})
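One path the refactor preserves but the new tests do not exercise directly: with no `JSONRegexMatch` or `ResponseRegex` configured, the whole reply is handed to `returnResult`, which still accepts either a JSON array of calls or a single object. A quick sketch:

```go
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	cfg := functions.FunctionsConfig{} // no extraction regexes configured

	single := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
	array := `[{"function": "add", "arguments": {"x": 5, "y": 3}}, {"function": "subtract", "arguments": {"x": 10, "y": 7}}]`

	fmt.Println(functions.ParseFunctionCall(single, cfg)) // one call
	fmt.Println(functions.ParseFunctionCall(array, cfg))  // two calls
}
```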