mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat: add grammar and functions call support
This commit is contained in:
parent
a6839fd238
commit
f09ddd2983
@ -46,12 +46,16 @@ type Config struct {
|
|||||||
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
PromptCacheAll bool `yaml:"prompt_cache_all"`
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
PromptCacheRO bool `yaml:"prompt_cache_ro"`
|
||||||
|
|
||||||
PromptStrings, InputStrings []string
|
Grammar string `yaml:"grammar"`
|
||||||
InputToken [][]int
|
|
||||||
|
PromptStrings, InputStrings []string
|
||||||
|
InputToken [][]int
|
||||||
|
functionCallString, functionCallNameString string
|
||||||
}
|
}
|
||||||
|
|
||||||
type TemplateConfig struct {
|
type TemplateConfig struct {
|
||||||
Completion string `yaml:"completion"`
|
Completion string `yaml:"completion"`
|
||||||
|
Functions string `yaml:"function"`
|
||||||
Chat string `yaml:"chat"`
|
Chat string `yaml:"chat"`
|
||||||
Edit string `yaml:"edit"`
|
Edit string `yaml:"edit"`
|
||||||
}
|
}
|
||||||
@ -181,6 +185,10 @@ func updateConfig(config *Config, input *OpenAIRequest) {
|
|||||||
config.TopP = input.TopP
|
config.TopP = input.TopP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.Grammar != "" {
|
||||||
|
config.Grammar = input.Grammar
|
||||||
|
}
|
||||||
|
|
||||||
if input.Temperature != 0 {
|
if input.Temperature != 0 {
|
||||||
config.Temperature = input.Temperature
|
config.Temperature = input.Temperature
|
||||||
}
|
}
|
||||||
@ -262,6 +270,24 @@ func updateConfig(config *Config, input *OpenAIRequest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Can be either a string or an object
|
||||||
|
switch fnc := input.FunctionCall.(type) {
|
||||||
|
case string:
|
||||||
|
if fnc != "" {
|
||||||
|
config.functionCallString = fnc
|
||||||
|
}
|
||||||
|
case map[string]interface{}:
|
||||||
|
var name string
|
||||||
|
n, exists := fnc["name"]
|
||||||
|
if exists {
|
||||||
|
nn, e := n.(string)
|
||||||
|
if !e {
|
||||||
|
name = nn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.functionCallNameString = name
|
||||||
|
}
|
||||||
|
|
||||||
switch p := input.Prompt.(type) {
|
switch p := input.Prompt.(type) {
|
||||||
case string:
|
case string:
|
||||||
config.PromptStrings = append(config.PromptStrings, p)
|
config.PromptStrings = append(config.PromptStrings, p)
|
||||||
|
150
api/openai.go
150
api/openai.go
@ -17,6 +17,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
|
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
|
||||||
llama "github.com/go-skynet/go-llama.cpp"
|
llama "github.com/go-skynet/go-llama.cpp"
|
||||||
@ -73,8 +74,12 @@ type Choice struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role,omitempty" yaml:"role"`
|
// The message role
|
||||||
|
Role string `json:"role,omitempty" yaml:"role"`
|
||||||
|
// The message content
|
||||||
Content string `json:"content,omitempty" yaml:"content"`
|
Content string `json:"content,omitempty" yaml:"content"`
|
||||||
|
// A result of a function call
|
||||||
|
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIModel struct {
|
type OpenAIModel struct {
|
||||||
@ -104,6 +109,10 @@ type OpenAIRequest struct {
|
|||||||
// Messages is read only by chat/completion API calls
|
// Messages is read only by chat/completion API calls
|
||||||
Messages []Message `json:"messages" yaml:"messages"`
|
Messages []Message `json:"messages" yaml:"messages"`
|
||||||
|
|
||||||
|
// A list of available functions to call
|
||||||
|
Functions []grammar.Function `json:"functions" yaml:"functions"`
|
||||||
|
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||||
|
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Echo bool `json:"echo"`
|
Echo bool `json:"echo"`
|
||||||
// Common options between all the API calls
|
// Common options between all the API calls
|
||||||
@ -134,6 +143,9 @@ type OpenAIRequest struct {
|
|||||||
Mode int `json:"mode"`
|
Mode int `json:"mode"`
|
||||||
Step int `json:"step"`
|
Step int `json:"step"`
|
||||||
|
|
||||||
|
// A grammar to constrain the LLM output
|
||||||
|
Grammar string `json:"grammar" yaml:"grammar"`
|
||||||
|
|
||||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
TypicalP float64 `json:"typical_p" yaml:"typical_p"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -345,6 +357,23 @@ func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
||||||
|
// TODO: replace this with config settings
|
||||||
|
// Allow the user to set custom actions via config file
|
||||||
|
// to be "embedded" in each model
|
||||||
|
const noActionName = "answer"
|
||||||
|
const noActionDescription = "use this action to answer without performing any action"
|
||||||
|
|
||||||
|
noActionGrammar := grammar.Function{
|
||||||
|
Name: noActionName,
|
||||||
|
Description: noActionDescription,
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to reply the user with",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
|
||||||
initialMessage := OpenAIResponse{
|
initialMessage := OpenAIResponse{
|
||||||
@ -368,6 +397,8 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||||||
close(responses)
|
close(responses)
|
||||||
}
|
}
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
processFunctions := false
|
||||||
|
funcs := []grammar.Function{}
|
||||||
model, input, err := readInput(c, o.loader, true)
|
model, input, err := readInput(c, o.loader, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
@ -377,8 +408,33 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
log.Debug().Msgf("Configuration read: %+v", config)
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
// process functions if we have any defined or if we have a function call string
|
||||||
|
if len(input.Functions) > 0 &&
|
||||||
|
((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) {
|
||||||
|
log.Debug().Msgf("Response needs to process functions")
|
||||||
|
|
||||||
|
var funcs grammar.Functions = input.Functions
|
||||||
|
processFunctions = true
|
||||||
|
|
||||||
|
// Force picking one of the functions by the request
|
||||||
|
if config.functionCallNameString != "" {
|
||||||
|
funcs = funcs.Select(config.functionCallNameString)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the no action function
|
||||||
|
funcs = append(funcs, noActionGrammar)
|
||||||
|
|
||||||
|
// Update input grammar
|
||||||
|
jsStruct := funcs.ToJSONStructure()
|
||||||
|
config.Grammar = jsStruct.Grammar("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// functions are not supported in stream mode (yet?)
|
||||||
|
toStream := input.Stream && !processFunctions
|
||||||
|
|
||||||
|
log.Debug().Msgf("Parameters: %+v", config)
|
||||||
|
|
||||||
var predInput string
|
var predInput string
|
||||||
|
|
||||||
@ -397,7 +453,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
predInput = strings.Join(mess, "\n")
|
predInput = strings.Join(mess, "\n")
|
||||||
|
|
||||||
if input.Stream {
|
if toStream {
|
||||||
log.Debug().Msgf("Stream request received")
|
log.Debug().Msgf("Stream request received")
|
||||||
c.Context().SetContentType("text/event-stream")
|
c.Context().SetContentType("text/event-stream")
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
@ -409,20 +465,35 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
templateFile := config.Model
|
templateFile := config.Model
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" {
|
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||||
templateFile = config.TemplateConfig.Chat
|
templateFile = config.TemplateConfig.Chat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||||
|
templateFile = config.TemplateConfig.Functions
|
||||||
|
}
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
|
||||||
Input string
|
Input string
|
||||||
}{Input: predInput})
|
Functions []grammar.Function
|
||||||
|
}{
|
||||||
|
Input: predInput,
|
||||||
|
Functions: funcs,
|
||||||
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
predInput = templatedInput
|
predInput = templatedInput
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Stream {
|
log.Debug().Msgf("Prompt: %s", predInput)
|
||||||
|
if processFunctions {
|
||||||
|
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
if toStream {
|
||||||
responses := make(chan OpenAIResponse)
|
responses := make(chan OpenAIResponse)
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses)
|
go process(predInput, input, config, o.loader, responses)
|
||||||
@ -459,6 +530,71 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
|
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
|
||||||
|
if processFunctions {
|
||||||
|
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||||
|
ss := map[string]interface{}{}
|
||||||
|
json.Unmarshal([]byte(s), &ss)
|
||||||
|
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||||
|
|
||||||
|
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||||
|
func_name := ss["function"]
|
||||||
|
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||||
|
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||||
|
d, _ := json.Marshal(args)
|
||||||
|
|
||||||
|
ss["arguments"] = string(d)
|
||||||
|
ss["name"] = func_name
|
||||||
|
|
||||||
|
// if do nothing, reply with a message
|
||||||
|
if func_name == noActionName {
|
||||||
|
log.Debug().Msgf("nothing to do, computing a reply")
|
||||||
|
|
||||||
|
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||||
|
arguments := map[string]interface{}{}
|
||||||
|
json.Unmarshal([]byte(d), &arguments)
|
||||||
|
m, exists := arguments["message"]
|
||||||
|
if exists {
|
||||||
|
switch message := m.(type) {
|
||||||
|
case string:
|
||||||
|
if message != "" {
|
||||||
|
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||||
|
message = Finetune(*config, predInput, message)
|
||||||
|
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||||
|
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: message}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||||
|
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||||
|
// Note: This costs (in term of CPU) another computation
|
||||||
|
config.Grammar = ""
|
||||||
|
predFunc, err := ModelInference(predInput, o.loader, *config, o, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prediction, err := predFunc()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("inference error: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prediction = Finetune(*config, predInput, prediction)
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: prediction}})
|
||||||
|
} else {
|
||||||
|
// otherwise reply with the function call
|
||||||
|
*c = append(*c, Choice{
|
||||||
|
FinishReason: "function_call",
|
||||||
|
Message: &Message{Role: "function", FunctionCall: ss},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -189,6 +189,8 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
|
|||||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar))
|
||||||
|
|
||||||
if c.PromptCachePath != "" {
|
if c.PromptCachePath != "" {
|
||||||
// Create parent directory
|
// Create parent directory
|
||||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||||
|
50
pkg/grammar/functions.go
Normal file
50
pkg/grammar/functions.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
	// Function describes one callable function: its name, a free-text
	// description and a parameter specification.
	Function struct {
		// Name is the identifier of the function.
		Name string `json:"name"`
		// Description is the human-readable explanation of the function.
		Description string `json:"description"`
		// Parameters is the parameter specification. ToJSONStructure reads
		// its "properties" entry.
		Parameters map[string]interface{} `json:"parameters"`
	}

	// Functions is a list of Function values.
	Functions []Function
)
|
||||||
|
|
||||||
|
func (f Functions) ToJSONStructure() JSONStructure {
|
||||||
|
js := JSONStructure{}
|
||||||
|
for _, function := range f {
|
||||||
|
// t := function.Parameters["type"]
|
||||||
|
//tt := t.(string)
|
||||||
|
|
||||||
|
properties := function.Parameters["properties"]
|
||||||
|
dat, _ := json.Marshal(properties)
|
||||||
|
prop := map[string]interface{}{}
|
||||||
|
json.Unmarshal(dat, &prop)
|
||||||
|
js.OneOf = append(js.OneOf, Item{
|
||||||
|
Type: "object",
|
||||||
|
Properties: Properties{
|
||||||
|
Function: FunctionName{Const: function.Name},
|
||||||
|
Arguments: Argument{
|
||||||
|
Type: "object",
|
||||||
|
Properties: prop,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return js
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select returns a list of functions containing the function with the given name
|
||||||
|
func (f Functions) Select(name string) Functions {
|
||||||
|
var funcs Functions
|
||||||
|
|
||||||
|
for _, f := range f {
|
||||||
|
if f.Name == name {
|
||||||
|
funcs = []Function{f}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return funcs
|
||||||
|
}
|
13
pkg/grammar/grammar_suite_test.go
Normal file
13
pkg/grammar/grammar_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGrammar(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Grammar test suite")
|
||||||
|
}
|
222
pkg/grammar/json_schema.go
Normal file
222
pkg/grammar/json_schema.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// SPACE_RULE is the body of the "space" rule: one optional blank.
	SPACE_RULE = `" "?`

	// PRIMITIVE_RULES maps a JSON-schema primitive type name to the grammar
	// rule body used for values of that type.
	PRIMITIVE_RULES = map[string]string{
		"boolean": `("true" | "false") space`,
		// Optional sign, integer part without leading zeros, optional
		// fraction and exponent.
		"number": `("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space`,
		// Same as "number" but without fraction and exponent.
		"integer": `("-"? ([0-9] | [1-9] [0-9]*)) space`,
		"string":  `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
		"null":    `"null" space`,
	}

	// INVALID_RULE_CHARS_RE matches runs of characters other than ASCII
	// letters, digits and '-'.
	INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
	// GRAMMAR_LITERAL_ESCAPE_RE matches the characters that have an entry in
	// GRAMMAR_LITERAL_ESCAPES.
	GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`)
	// GRAMMAR_LITERAL_ESCAPES maps each such character to its escaped form.
	GRAMMAR_LITERAL_ESCAPES = map[string]string{
		"\r": `\r`,
		"\n": `\n`,
		`"`:  `\"`,
	}
)
|
||||||
|
|
||||||
|
// JSONSchemaConverter accumulates grammar rules while a JSON schema is
// walked and renders them as a grammar.
type JSONSchemaConverter struct {
	// rules maps every rule name registered so far to its rule body.
	rules map[string]string
	// propOrder maps a property name to its position in the caller-supplied
	// property order.
	propOrder map[string]int
}
|
||||||
|
|
||||||
|
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
|
||||||
|
propOrderSlice := strings.Split(propOrder, ",")
|
||||||
|
propOrderMap := make(map[string]int)
|
||||||
|
for idx, name := range propOrderSlice {
|
||||||
|
propOrderMap[name] = idx
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := make(map[string]string)
|
||||||
|
rules["space"] = SPACE_RULE
|
||||||
|
|
||||||
|
return &JSONSchemaConverter{
|
||||||
|
propOrder: propOrderMap,
|
||||||
|
rules: rules,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string {
|
||||||
|
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string {
|
||||||
|
return GRAMMAR_LITERAL_ESCAPES[match]
|
||||||
|
})
|
||||||
|
return fmt.Sprintf(`"%s"`, escaped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) addRule(name, rule string) string {
|
||||||
|
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
|
||||||
|
key := escName
|
||||||
|
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
|
||||||
|
i := 0
|
||||||
|
for {
|
||||||
|
key = fmt.Sprintf("%s%d", escName, i)
|
||||||
|
if _, ok := sc.rules[key]; !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sc.rules[key] = rule
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) formatGrammar() string {
|
||||||
|
var lines []string
|
||||||
|
for name, rule := range sc.rules {
|
||||||
|
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string {
|
||||||
|
st, existType := schema["type"]
|
||||||
|
var schemaType string
|
||||||
|
if existType {
|
||||||
|
schemaType = st.(string)
|
||||||
|
}
|
||||||
|
ruleName := name
|
||||||
|
if name == "" {
|
||||||
|
ruleName = "root"
|
||||||
|
}
|
||||||
|
_, oneOfExists := schema["oneOf"]
|
||||||
|
_, anyOfExists := schema["anyOf"]
|
||||||
|
if oneOfExists || anyOfExists {
|
||||||
|
var alternatives []string
|
||||||
|
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
|
||||||
|
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
|
||||||
|
|
||||||
|
if oneOfExists {
|
||||||
|
for i, altSchema := range oneOfSchemas {
|
||||||
|
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
|
||||||
|
alternatives = append(alternatives, alternative)
|
||||||
|
}
|
||||||
|
} else if anyOfExists {
|
||||||
|
for i, altSchema := range anyOfSchemas {
|
||||||
|
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
|
||||||
|
alternatives = append(alternatives, alternative)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := strings.Join(alternatives, " | ")
|
||||||
|
return sc.addRule(ruleName, rule)
|
||||||
|
} else if constVal, exists := schema["const"]; exists {
|
||||||
|
return sc.addRule(ruleName, sc.formatLiteral(constVal))
|
||||||
|
} else if enumVals, exists := schema["enum"].([]interface{}); exists {
|
||||||
|
var enumRules []string
|
||||||
|
for _, enumVal := range enumVals {
|
||||||
|
enumRule := sc.formatLiteral(enumVal)
|
||||||
|
enumRules = append(enumRules, enumRule)
|
||||||
|
}
|
||||||
|
rule := strings.Join(enumRules, " | ")
|
||||||
|
return sc.addRule(ruleName, rule)
|
||||||
|
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
|
||||||
|
propOrder := sc.propOrder
|
||||||
|
var propPairs []struct {
|
||||||
|
propName string
|
||||||
|
propSchema map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for propName, propSchema := range properties {
|
||||||
|
propPairs = append(propPairs, struct {
|
||||||
|
propName string
|
||||||
|
propSchema map[string]interface{}
|
||||||
|
}{propName: propName, propSchema: propSchema.(map[string]interface{})})
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(propPairs, func(i, j int) bool {
|
||||||
|
iOrder := propOrder[propPairs[i].propName]
|
||||||
|
jOrder := propOrder[propPairs[j].propName]
|
||||||
|
if iOrder != 0 && jOrder != 0 {
|
||||||
|
return iOrder < jOrder
|
||||||
|
}
|
||||||
|
return propPairs[i].propName < propPairs[j].propName
|
||||||
|
})
|
||||||
|
|
||||||
|
var rule strings.Builder
|
||||||
|
rule.WriteString(`"{" space`)
|
||||||
|
|
||||||
|
for i, propPair := range propPairs {
|
||||||
|
propName := propPair.propName
|
||||||
|
propSchema := propPair.propSchema
|
||||||
|
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName))
|
||||||
|
|
||||||
|
if i > 0 {
|
||||||
|
rule.WriteString(` "," space`)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName))
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.WriteString(` "}" space`)
|
||||||
|
return sc.addRule(ruleName, rule.String())
|
||||||
|
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
|
||||||
|
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName))
|
||||||
|
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
|
||||||
|
return sc.addRule(ruleName, rule)
|
||||||
|
} else {
|
||||||
|
primitiveRule, exists := PRIMITIVE_RULES[schemaType]
|
||||||
|
if !exists {
|
||||||
|
panic(fmt.Sprintf("Unrecognized schema: %v", schema))
|
||||||
|
}
|
||||||
|
return sc.addRule(schemaType, primitiveRule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
|
||||||
|
sc.visit(schema, "")
|
||||||
|
return sc.formatGrammar()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
|
||||||
|
var schema map[string]interface{}
|
||||||
|
_ = json.Unmarshal(b, &schema)
|
||||||
|
return sc.Grammar(schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jsonString returns the JSON encoding of v, or the empty string when v
// cannot be encoded.
//
// HTML escaping is disabled so that '<', '>' and '&' stay literal instead of
// being rewritten as \u003c, \u003e and \u0026; the result is used to build
// grammar literals that must match the value verbatim.
func jsonString(v interface{}) string {
	var out strings.Builder
	enc := json.NewEncoder(&out)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(v); err != nil {
		return ""
	}
	// Encode terminates the value with a newline that is not part of it.
	return strings.TrimSuffix(out.String(), "\n")
}
|
||||||
|
|
||||||
|
type (
	// FunctionName is a schema fragment that pins a value to one constant.
	FunctionName struct {
		Const string `json:"const"`
	}

	// Properties lists the two members of a function-call object: the
	// constant function name and the arguments object.
	Properties struct {
		Function  FunctionName `json:"function"`
		Arguments Argument     `json:"arguments"`
	}

	// Argument is the schema of the "arguments" member: a type name plus the
	// schemas of its properties.
	Argument struct {
		Type       string                 `json:"type"`
		Properties map[string]interface{} `json:"properties"`
	}

	// Item is one alternative of a JSONStructure.
	Item struct {
		Type       string     `json:"type"`
		Properties Properties `json:"properties"`
	}

	// JSONStructure is a schema made of alternatives, encoded under "oneOf"
	// and/or "anyOf"; an empty list is omitted from the encoding.
	JSONStructure struct {
		OneOf []Item `json:"oneOf,omitempty"`
		AnyOf []Item `json:"anyOf,omitempty"`
	}
)
|
||||||
|
|
||||||
|
func (j JSONStructure) Grammar(propOrder string) string {
|
||||||
|
dat, _ := json.Marshal(j)
|
||||||
|
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat)
|
||||||
|
}
|
113
pkg/grammar/json_schema_test.go
Normal file
113
pkg/grammar/json_schema_test.go
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
package grammar_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
. "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testInput1 = `
|
||||||
|
{
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"function": {"const": "create_event"},
|
||||||
|
"arguments": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"title": {"type": "string"},
|
||||||
|
"date": {"type": "string"},
|
||||||
|
"time": {"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"function": {"const": "search"},
|
||||||
|
"arguments": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
inputResult1 = `root-0-function ::= "\"create_event\""
|
||||||
|
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
|
||||||
|
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
|
||||||
|
root ::= root-0 | root-1
|
||||||
|
space ::= " "?
|
||||||
|
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
|
||||||
|
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
|
||||||
|
string ::= "\"" [ \t!#-\[\]-~]* "\"" space
|
||||||
|
root-1-function ::= "\"search\""`
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("JSON schema grammar tests", func() {
|
||||||
|
Context("JSON", func() {
|
||||||
|
It("generates a valid grammar from JSON schema", func() {
|
||||||
|
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
|
||||||
|
results := strings.Split(inputResult1, "\n")
|
||||||
|
for _, r := range results {
|
||||||
|
if r != "" {
|
||||||
|
Expect(grammar).To(ContainSubstring(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||||
|
})
|
||||||
|
It("generates a valid grammar from JSON Objects", func() {
|
||||||
|
|
||||||
|
structuredGrammar := JSONStructure{
|
||||||
|
OneOf: []Item{
|
||||||
|
{
|
||||||
|
Type: "object",
|
||||||
|
Properties: Properties{
|
||||||
|
Function: FunctionName{
|
||||||
|
Const: "create_event",
|
||||||
|
},
|
||||||
|
Arguments: Argument{ // this is OpenAI's parameter
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"title": map[string]string{"type": "string"},
|
||||||
|
"date": map[string]string{"type": "string"},
|
||||||
|
"time": map[string]string{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "object",
|
||||||
|
Properties: Properties{
|
||||||
|
Function: FunctionName{
|
||||||
|
Const: "search",
|
||||||
|
},
|
||||||
|
Arguments: Argument{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"query": map[string]string{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
grammar := structuredGrammar.Grammar("")
|
||||||
|
results := strings.Split(inputResult1, "\n")
|
||||||
|
for _, r := range results {
|
||||||
|
if r != "" {
|
||||||
|
Expect(grammar).To(ContainSubstring(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
Loading…
Reference in New Issue
Block a user