diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 341dc34b..b2e7aa75 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -25,7 +25,7 @@ import (
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/chat/completions [post]
 func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	emptyMessage := ""
+	textContentToReturn := ""
 	id := uuid.New().String()
 	created := int(time.Now().Unix())
 
@@ -34,7 +34,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 			ID:      id,
 			Created: created,
 			Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
-			Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
+			Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
 			Object:  "chat.completion.chunk",
 		}
 		responses <- initialMessage
@@ -69,6 +69,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		result = functions.CleanupLLMResult(result, config.FunctionsConfig)
 		results := functions.ParseFunctionCall(result, config.FunctionsConfig)
+		textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
 
 		noActionToRun := len(results) > 0 && results[0].Name == noAction || len(results) == 0
 
 		switch {
@@ -77,7 +78,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 				ID:      id,
 				Created: created,
 				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
-				Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
+				Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
 				Object:  "chat.completion.chunk",
 			}
 			responses <- initialMessage
@@ -449,7 +450,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 				{
 					FinishReason: finishReason,
 					Index:        0,
-					Delta:        &schema.Message{Content: &emptyMessage},
+					Delta:        &schema.Message{Content: &textContentToReturn},
 				}},
 			Object: "chat.completion.chunk",
 			Usage:  *usage,
@@ -473,6 +474,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		s = functions.CleanupLLMResult(s, config.FunctionsConfig)
 		results := functions.ParseFunctionCall(s, config.FunctionsConfig)
+		textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
 
 		noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
 
 		switch {
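Not part of the patch, but useful for orientation: the endpoint now runs both parsers over the same raw completion, so captured free-form text rides along as the assistant `content` while the tool calls are extracted separately. A minimal sketch, assuming a model prompted to wrap tool calls in `<tool_call>` tags and side notes in `<sketchpad>` tags (the same conventions the tests at the bottom of this patch use):

```go
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	// Assumed model output; the tag names are illustrative, not mandated.
	raw := `<sketchpad>adding the numbers</sketchpad>
<tool_call>{"function": "add", "arguments": {"x": 5, "y": 3}}</tool_call>`

	cfg := functions.FunctionsConfig{
		JSONRegexMatch:   []string{`(?s)<tool_call>(.*?)</tool_call>`},
		CaptureLLMResult: []string{`(?s)<sketchpad>(.*?)</sketchpad>`},
	}

	// The same two calls chat.go now makes on every raw result:
	fmt.Println(functions.ParseTextContent(raw, cfg))  // adding the numbers
	fmt.Println(functions.ParseFunctionCall(raw, cfg)) // [{add {"x":5,"y":3}}]
}
```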
diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go
index d6e9d320..7bb3e6bd 100644
--- a/pkg/functions/parse.go
+++ b/pkg/functions/parse.go
@@ -3,6 +3,7 @@ package functions
 import (
 	"encoding/json"
 	"regexp"
+	"strings"
 
 	"github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/rs/zerolog/log"
@@ -59,6 +60,11 @@ type FunctionsConfig struct {
 	// ReplaceLLMResult allow to replace strings in the results before parsing them
 	ReplaceLLMResult []ReplaceResult `yaml:"replace_llm_results"`
 
+	// CaptureLLMResult is a list of regexes to extract a string from the LLM response
+	// that is used as the return string when using tools.
+	// This is useful e.g. when the LLM outputs reasoning alongside tool calls and we
+	// want to return that reasoning as the message content.
+	CaptureLLMResult []string `yaml:"capture_llm_results"`
+
 	// FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
 	// instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
 	// This might be useful for certain models trained with the function name as the first token.
@@ -109,6 +115,20 @@ func CleanupLLMResult(llmresult string, functionConfig FunctionsConfig) string {
 	return llmresult
 }
 
+func ParseTextContent(llmresult string, functionConfig FunctionsConfig) string {
+	for _, r := range functionConfig.CaptureLLMResult {
+		// We use a regex to extract the text content from the response
+		var respRegex = regexp.MustCompile(r)
+		match := respRegex.FindStringSubmatch(llmresult)
+		// a match with a capture group yields at least two entries
+		if len(match) >= 2 {
+			return strings.TrimSpace(match[1])
+		}
+	}
+
+	return ""
+}
+
 func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncCallResults {
 	log.Debug().Msgf("LLM result: %s", llmresult)
 
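The struct tags above are what surface these fields in a model's YAML config. A hypothetical snippet (the `function:` block placement and the tag patterns are assumptions for illustration, mirroring the tests below):

```yaml
# Illustrative values only; key names come from the yaml struct tags above.
function:
  # free-form text to capture and hand back as the assistant message content
  capture_llm_results:
    - (?s)<sketchpad>(.*?)</sketchpad>
  # every match becomes a candidate tool call (see the hunks below)
  json_regex_match:
    - (?s)<tool_call>(.*?)</tool_call>
```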
@@ -127,47 +147,52 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
 	}
 
 	results := []FuncCallResults{}
+	llmResults := []string{}
 
-	returnResult := func(s string) (result []FuncCallResults, e error) {
+	returnResult := func(results []string) (result []FuncCallResults, e error) {
 		// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
-		var ss []map[string]interface{}
 		result = make([]FuncCallResults, 0)
-		s = utils.EscapeNewLines(s)
-		err := json.Unmarshal([]byte(s), &ss)
-		if err != nil {
-			// If the LLM result is a single object, try unmarshaling it into a single map
-			var singleObj map[string]interface{}
-			err = json.Unmarshal([]byte(s), &singleObj)
+
+		for _, s := range results {
+			var ss []map[string]interface{}
+
+			s = utils.EscapeNewLines(s)
+			err := json.Unmarshal([]byte(s), &ss)
 			if err != nil {
-				log.Debug().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result in a single object or an array of JSON objects")
-			} else {
-				ss = []map[string]interface{}{singleObj}
-			}
-		}
-
-		log.Debug().Msgf("Function return: %s %+v", s, ss)
-
-		for _, s := range ss {
-			// The grammar defines the function name as "function", while OpenAI returns "name"
-			func_name, ok := s[functionNameKey]
-			if !ok {
-				continue
-				//return result, fmt.Errorf("unable to find function name in result")
-			}
-			// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
-			args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
-			if !ok {
-				continue
-				//return result, fmt.Errorf("unable to find arguments in result")
-			}
-			d, _ := json.Marshal(args)
-			funcName, ok := func_name.(string)
-			if !ok {
-				continue
-				//return result, fmt.Errorf("unable to cast function name to string")
+				// If the LLM result is a single object, try unmarshaling it into a single map
+				var singleObj map[string]interface{}
+				err = json.Unmarshal([]byte(s), &singleObj)
+				if err != nil {
+					log.Debug().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result in a single object or an array of JSON objects")
+				} else {
+					ss = []map[string]interface{}{singleObj}
+				}
 			}
-			result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
+
+			log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+			for _, s := range ss {
+				// The grammar defines the function name as "function", while OpenAI returns "name"
+				func_name, ok := s[functionNameKey]
+				if !ok {
+					continue
+					//return result, fmt.Errorf("unable to find function name in result")
+				}
+				// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
+				args, ok := s["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+				if !ok {
+					continue
+					//return result, fmt.Errorf("unable to find arguments in result")
+				}
+				d, _ := json.Marshal(args)
+				funcName, ok := func_name.(string)
+				if !ok {
+					continue
+					//return result, fmt.Errorf("unable to cast function name to string")
+				}
+
+				result = append(result, FuncCallResults{Name: funcName, Arguments: string(d)})
+			}
 		}
 
 		return result, nil
@@ -179,10 +204,16 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
 	for _, r := range functionConfig.JSONRegexMatch {
 		// We use a regex to extract the JSON object from the response
 		var respRegex = regexp.MustCompile(r)
-		match := respRegex.FindStringSubmatch(llmresult)
-		if len(match) >= 2 {
-			llmresult = match[1]
-			log.Debug().Msgf("LLM result(JSONRegexMatch): %s", llmresult)
+		match := respRegex.FindAllStringSubmatch(llmresult, -1)
+		var allMatches []string
+		for _, m := range match {
+			if len(m) > 1 {
+				// we match the first group
+				allMatches = append(allMatches, m[1])
+			}
+		}
+		if len(allMatches) > 0 {
+			llmResults = append(llmResults, allMatches...)
 			break
 		}
 	}
@@ -193,22 +224,25 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
 		// obviously, this expects the LLM to be stable and return correctly formatted JSON
 		// TODO: optimize this and pre-compile it
 		var respRegex = regexp.MustCompile(functionConfig.ResponseRegex)
-		match := respRegex.FindStringSubmatch(llmresult)
-		for i, name := range respRegex.SubexpNames() {
-			if i != 0 && name != "" && len(match) > i {
-				result[name] = match[i]
+		matches := respRegex.FindAllStringSubmatch(llmresult, -1)
+		for _, match := range matches {
+			for i, name := range respRegex.SubexpNames() {
+				if i != 0 && name != "" && len(match) > i {
+					result[name] = match[i]
+				}
 			}
-		}
 
-		// TODO: open point about multiple results and/or mixed with chat messages
-		// This is not handled as for now, we only expect one function call per response
-		functionName := result[functionNameKey]
-		if functionName == "" {
-			return results
+			functionName := result[functionNameKey]
+			if functionName == "" {
+				return results
+			}
+			results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
 		}
-		results = append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
 	} else {
-		results, _ = returnResult(llmresult)
+		if len(llmResults) == 0 {
+			llmResults = append(llmResults, llmresult)
+		}
+		results, _ = returnResult(llmResults)
 	}
 
 	return results
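The ResponseRegex branch gets the same multi-match treatment: FindAllStringSubmatch plus SubexpNames, applied once per match. A self-contained sketch of that pattern with an assumed regex (not one LocalAI ships):

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Named groups mirror the "function"/"arguments" keys the parser reads.
	re := regexp.MustCompile(`(?P<function>\w+)\((?P<arguments>[^)]*)\)`)
	// Every match now becomes a candidate call, not just the first one.
	for _, match := range re.FindAllStringSubmatch("add(1,2) subtract(5,3)", -1) {
		call := map[string]string{}
		for i, name := range re.SubexpNames() {
			if i != 0 && name != "" && len(match) > i {
				call[name] = match[i]
			}
		}
		fmt.Println(call) // map[arguments:1,2 function:add], then subtract
	}
}
```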
diff --git a/pkg/functions/parse_test.go b/pkg/functions/parse_test.go
index 5e266c50..01d8469f 100644
--- a/pkg/functions/parse_test.go
+++ b/pkg/functions/parse_test.go
@@ -215,5 +215,48 @@ Some text after the JSON
 			Expect(results[0].Name).To(Equal("\"add\""))
 			Expect(results[0].Arguments).To(Equal(`{"x":5,"y":"v\"value\"","z":"\"v\""}`))
 		})
+
+		It("should detect multiple function calls where the JSONRegexMatch is repeated", func() {
+			input := `
+Some text before the JSON
+<tool_call>{"function": "add", "arguments": {"x": 5, "y": 3}}</tool_call>
+<tool_call>{"function": "subtract", "arguments": {"x": 10, "y": 7}}</tool_call>
+Some text after the JSON
+`
+			functionConfig.JSONRegexMatch = []string{`(?s)<tool_call>(.*?)</tool_call>`}
+
+			results := ParseFunctionCall(input, functionConfig)
+			Expect(results).To(HaveLen(2))
+			Expect(results[0].Name).To(Equal("add"))
+			Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
+			Expect(results[1].Name).To(Equal("subtract"))
+			Expect(results[1].Arguments).To(Equal(`{"x":10,"y":7}`))
+		})
+	})
+	Context("ParseTextContent", func() {
+		It("Can extract notes from the LLM result", func() {
+			input := `
+	Some text before the JSON
+	<sketchpad>
+roses are red
+</sketchpad>
+	{"function": "subtract", "arguments": {"x": 10, "y": 7}}
+	Some text after the JSON
+	`
+			functionConfig.CaptureLLMResult = []string{`(?s)<sketchpad>(.*?)</sketchpad>`}
+			results := ParseTextContent(input, functionConfig)
+			Expect(results).To(Equal("roses are red"))
+		})
+
+		It("Defaults to empty if it doesn't catch any", func() {
+			input := `
+	Some text before the JSON
+	{"function": "subtract", "arguments": {"x": 10, "y": 7}}
+	Some text after the JSON
+	`
+			functionConfig.CaptureLLMResult = []string{`(?s)<sketchpad>(.*?)</sketchpad>`}
+			results := ParseTextContent(input, functionConfig)
+			Expect(results).To(Equal(""))
+		})
 	})
 })
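A possible follow-up case, not in this patch: both parsers applied to one result, mirroring the way chat.go now calls them back to back (the tag names are the ones already used above):

```go
It("extracts text and calls from the same result", func() {
	input := `<sketchpad>thinking out loud</sketchpad>
<tool_call>{"function": "add", "arguments": {"x": 1, "y": 2}}</tool_call>`
	functionConfig.CaptureLLMResult = []string{`(?s)<sketchpad>(.*?)</sketchpad>`}
	functionConfig.JSONRegexMatch = []string{`(?s)<tool_call>(.*?)</tool_call>`}

	Expect(ParseTextContent(input, functionConfig)).To(Equal("thinking out loud"))

	results := ParseFunctionCall(input, functionConfig)
	Expect(results).To(HaveLen(1))
	Expect(results[0].Name).To(Equal("add"))
	Expect(results[0].Arguments).To(Equal(`{"x":1,"y":2}`))
})
```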