diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 9adba8ea..2b0b10a8 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -216,10 +216,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup } // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + // Handle if we should return "name" instead of "function" + if config.FunctionsConfig.FunctionName { + jsStruct := funcs.ToJSONNameStructure() + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + } else { + jsStruct := funcs.ToJSONFunctionStructure() + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) + } case input.JSONFunctionGrammarObject != nil: config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) + case input.JSONFunctionGrammarObjectName != nil: + config.Grammar = input.JSONFunctionGrammarObjectName.Grammar("", config.FunctionsConfig.ParallelCalls) default: // Force picking one of the functions by the request if config.FunctionToCall() != "" { diff --git a/core/schema/openai.go b/core/schema/openai.go index a251ba68..177dc7ec 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -145,7 +145,8 @@ type OpenAIRequest struct { // A grammar to constrain the LLM output Grammar string `json:"grammar" yaml:"grammar"` - JSONFunctionGrammarObject *functions.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` + JSONFunctionGrammarObject *functions.JSONFunctionStructureFunction `json:"grammar_json_functions" yaml:"grammar_json_functions"` + JSONFunctionGrammarObjectName *functions.JSONFunctionStructureName `json:"grammar_json_name" yaml:"grammar_json_name"` Backend string `json:"backend" yaml:"backend"` diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go index 
f5e37d75..f13ffe01 100644 --- a/pkg/functions/functions.go +++ b/pkg/functions/functions.go @@ -19,8 +19,10 @@ type Tool struct { } type Tools []Tool -func (f Functions) ToJSONStructure() JSONFunctionStructure { - js := JSONFunctionStructure{} +// ToJSONFunctionStructure converts a list of functions to a JSON structure that can be parsed to a grammar +// This allows the LLM to return a response of the type: { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } } +func (f Functions) ToJSONFunctionStructure() JSONFunctionStructureFunction { + js := JSONFunctionStructureFunction{} for _, function := range f { // t := function.Parameters["type"] //tt := t.(string) @@ -43,9 +45,49 @@ func (f Functions) ToJSONStructure() JSONFunctionStructure { if js.Defs == nil { js.Defs = defsD } - js.OneOf = append(js.OneOf, Item{ + js.OneOf = append(js.OneOf, ItemFunction{ Type: "object", - Properties: Properties{ + Properties: FunctionProperties{ + Function: FunctionName{Const: function.Name}, + Arguments: Argument{ + Type: "object", + Properties: prop, + }, + }, + }) + } + return js +} + +// ToJSONNameStructure converts a list of functions to a JSON structure that can be parsed to a grammar +// This allows the LLM to return a response of the type: { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } } +func (f Functions) ToJSONNameStructure() JSONFunctionStructureName { + js := JSONFunctionStructureName{} + for _, function := range f { + // t := function.Parameters["type"] + //tt := t.(string) + + properties := function.Parameters["properties"] + defs := function.Parameters["$defs"] + dat, _ := json.Marshal(properties) + dat2, _ := json.Marshal(defs) + prop := map[string]interface{}{} + defsD := map[string]interface{}{} + + err := json.Unmarshal(dat, &prop) + if err != nil { + log.Error().Err(err).Msg("error unmarshalling dat") + } + err = json.Unmarshal(dat2, &defsD) + if err != nil { + log.Error().Err(err).Msg("error 
unmarshalling dat2") + } + if js.Defs == nil { + js.Defs = defsD + } + js.OneOf = append(js.OneOf, ItemName{ + Type: "object", + Properties: NameProperties{ Function: FunctionName{Const: function.Name}, Arguments: Argument{ Type: "object", diff --git a/pkg/functions/functions_test.go b/pkg/functions/functions_test.go index 97953a5e..9bed86ec 100644 --- a/pkg/functions/functions_test.go +++ b/pkg/functions/functions_test.go @@ -35,13 +35,21 @@ var _ = Describe("LocalAI grammar functions", func() { }, } - js := functions.ToJSONStructure() + js := functions.ToJSONFunctionStructure() Expect(len(js.OneOf)).To(Equal(2)) Expect(js.OneOf[0].Properties.Function.Const).To(Equal("create_event")) Expect(js.OneOf[0].Properties.Arguments.Properties["event_name"].(map[string]interface{})["type"]).To(Equal("string")) Expect(js.OneOf[0].Properties.Arguments.Properties["event_date"].(map[string]interface{})["type"]).To(Equal("string")) Expect(js.OneOf[1].Properties.Function.Const).To(Equal("search")) Expect(js.OneOf[1].Properties.Arguments.Properties["query"].(map[string]interface{})["type"]).To(Equal("string")) + + jsN := functions.ToJSONNameStructure() + Expect(len(jsN.OneOf)).To(Equal(2)) + Expect(jsN.OneOf[0].Properties.Function.Const).To(Equal("create_event")) + Expect(jsN.OneOf[0].Properties.Arguments.Properties["event_name"].(map[string]interface{})["type"]).To(Equal("string")) + Expect(jsN.OneOf[0].Properties.Arguments.Properties["event_date"].(map[string]interface{})["type"]).To(Equal("string")) + Expect(jsN.OneOf[1].Properties.Function.Const).To(Equal("search")) + Expect(jsN.OneOf[1].Properties.Arguments.Properties["query"].(map[string]interface{})["type"]).To(Equal("string")) }) }) Context("Select()", func() { diff --git a/pkg/functions/grammar_json_schema.go b/pkg/functions/grammar_json_schema.go index 01046390..ede52fab 100644 --- a/pkg/functions/grammar_json_schema.go +++ b/pkg/functions/grammar_json_schema.go @@ -271,28 +271,49 @@ type FunctionName struct { Const 
string `json:"const"` } -type Properties struct { +type FunctionProperties struct { Function FunctionName `json:"function"` Arguments Argument `json:"arguments"` } +type NameProperties struct { + Function FunctionName `json:"name"` + Arguments Argument `json:"arguments"` +} + type Argument struct { Type string `json:"type"` Properties map[string]interface{} `json:"properties"` } -type Item struct { - Type string `json:"type"` - Properties Properties `json:"properties"` +type ItemName struct { + Type string `json:"type"` + Properties NameProperties `json:"properties"` } -type JSONFunctionStructure struct { - OneOf []Item `json:"oneOf,omitempty"` - AnyOf []Item `json:"anyOf,omitempty"` +type ItemFunction struct { + Type string `json:"type"` + Properties FunctionProperties `json:"properties"` +} + +type JSONFunctionStructureName struct { + OneOf []ItemName `json:"oneOf,omitempty"` + AnyOf []ItemName `json:"anyOf,omitempty"` Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructure) Grammar(propOrder string, maybeArray bool) string { +func (j JSONFunctionStructureName) Grammar(propOrder string, maybeArray bool) string { + dat, _ := json.Marshal(j) + return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) +} + +type JSONFunctionStructureFunction struct { + OneOf []ItemFunction `json:"oneOf,omitempty"` + AnyOf []ItemFunction `json:"anyOf,omitempty"` + Defs map[string]interface{} `json:"$defs,omitempty"` +} + +func (j JSONFunctionStructureFunction) Grammar(propOrder string, maybeArray bool) string { dat, _ := json.Marshal(j) return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) } diff --git a/pkg/functions/grammar_json_schema_test.go b/pkg/functions/grammar_json_schema_test.go index fc9029a8..83fae372 100644 --- a/pkg/functions/grammar_json_schema_test.go +++ b/pkg/functions/grammar_json_schema_test.go @@ -72,6 +72,70 @@ arr ::= (",\n" realvalue)* )? 
"]" root-1-function ::= "\"search\""` + + testInput2 = ` +{ + "oneOf": [ + { + "type": "object", + "properties": { + "name": {"const": "create_event"}, + "arguments": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "date": {"type": "string"}, + "time": {"type": "string"} + } + } + } + }, + { + "type": "object", + "properties": { + "name": {"const": "search"}, + "arguments": { + "type": "object", + "properties": { + "query": {"type": "string"} + } + } + } + } + ] +}` + + inputResult3 = `root-0-name ::= "\"create_event\"" +root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"name\"" space ":" space root-0-name "}" space +root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space +root ::= root-0 | root-1 +space ::= " "? +root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space +root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"name\"" space ":" space root-1-name "}" space +string ::= "\"" ( +[^"\\] | +"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) +)* "\"" space +root-1-name ::= "\"search\""` + + inputResult4 = `root-0-name ::= "\"create_event\"" +root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"name\"" space ":" space root-0-name "}" space +root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space +realvalue ::= root-0 | root-1 +root ::= arr | realvalue +space ::= " "? 
+root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space +root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"name\"" space ":" space root-1-name "}" space +string ::= "\"" ( +[^"\\] | +"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) +)* "\"" space +arr ::= +"[\n" ( + realvalue +(",\n" realvalue)* +)? "]" +root-1-name ::= "\"search\""` ) var _ = Describe("JSON schema grammar tests", func() { @@ -86,13 +150,23 @@ var _ = Describe("JSON schema grammar tests", func() { } Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) }) + It("generates a valid grammar from JSON schema", func() { + grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput2), false) + results := strings.Split(inputResult3, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) + }) It("generates a valid grammar from JSON Objects", func() { - structuredGrammar := JSONFunctionStructure{ - OneOf: []Item{ + structuredGrammar := JSONFunctionStructureFunction{ + OneOf: []ItemFunction{ { Type: "object", - Properties: Properties{ + Properties: FunctionProperties{ Function: FunctionName{ Const: "create_event", }, @@ -108,7 +182,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, { Type: "object", - Properties: Properties{ + Properties: FunctionProperties{ Function: FunctionName{ Const: "search", }, @@ -133,11 +207,11 @@ var _ = Describe("JSON schema grammar tests", func() { }) It("generates a valid grammar from JSON Objects for multiple function return", func() { - structuredGrammar := JSONFunctionStructure{ - OneOf: []Item{ + structuredGrammar := JSONFunctionStructureFunction{ + OneOf: []ItemFunction{ { Type: "object", - Properties: Properties{ + Properties: FunctionProperties{ Function: 
FunctionName{ Const: "create_event", }, @@ -153,7 +227,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, { Type: "object", - Properties: Properties{ + Properties: FunctionProperties{ Function: FunctionName{ Const: "search", }, @@ -176,5 +250,50 @@ var _ = Describe("JSON schema grammar tests", func() { } Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar) }) + + It("generates a valid grammar from JSON Objects for multiple function return", func() { + structuredGrammar := JSONFunctionStructureName{ + OneOf: []ItemName{ + { + Type: "object", + Properties: NameProperties{ + Function: FunctionName{ + Const: "create_event", + }, + Arguments: Argument{ // this is OpenAI's parameter + Type: "object", + Properties: map[string]interface{}{ + "title": map[string]string{"type": "string"}, + "date": map[string]string{"type": "string"}, + "time": map[string]string{"type": "string"}, + }, + }, + }, + }, + { + Type: "object", + Properties: NameProperties{ + Function: FunctionName{ + Const: "search", + }, + Arguments: Argument{ + Type: "object", + Properties: map[string]interface{}{ + "query": map[string]string{"type": "string"}, + }, + }, + }, + }, + }} + + grammar := structuredGrammar.Grammar("", true) + results := strings.Split(inputResult4, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar) + }) }) }) diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 26312560..13ac11e5 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -15,6 +15,11 @@ type FunctionsConfig struct { ParallelCalls bool `yaml:"parallel_calls"` NoGrammar bool `yaml:"no_grammar"` ResponseRegex string `yaml:"response_regex"` + + // FunctionName enables the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } } + // instead of { "function": "function_name", "arguments": { "arg1": 
"value1", "arg2": "value2" } }. + // This might be useful for certain models trained with the function name as the first token. + FunctionName bool `yaml:"return_name_in_function_response"` } type FuncCallResults struct { @@ -26,6 +31,11 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC multipleResults := functionConfig.ParallelCalls useGrammars := !functionConfig.NoGrammar + functionNameKey := "function" + if functionConfig.FunctionName { + functionNameKey = "name" + } + results := []FuncCallResults{} // if no grammar is used, we have to extract function and arguments from the result @@ -46,12 +56,12 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC // TODO: open point about multiple results and/or mixed with chat messages // This is not handled as for now, we only expect one function call per response - functionName := result["function"] + functionName := result[functionNameKey] if functionName == "" { return results } - return append(results, FuncCallResults{Name: result["function"], Arguments: result["arguments"]}) + return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]}) } // with grammars @@ -66,7 +76,7 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC log.Debug().Msgf("Function return: %s %+v", s, ss) for _, s := range ss { - func_name, ok := s["function"] + func_name, ok := s[functionNameKey] if !ok { continue } @@ -93,7 +103,7 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC log.Debug().Msgf("Function return: %s %+v", s, ss) // The grammar defines the function name as "function", while OpenAI returns "name" - func_name, ok := ss["function"] + func_name, ok := ss[functionNameKey] if !ok { return results }