diff --git a/api/config/config.go b/api/config/config.go
index 9df8d3e0..2d4ab2ce 100644
--- a/api/config/config.go
+++ b/api/config/config.go
@@ -49,6 +49,8 @@ type Config struct {
 	functionCallString, functionCallNameString string
 
 	FunctionsConfig Functions `yaml:"function"`
+
+	SystemPrompt string `yaml:"system_prompt"`
 }
 
 type Functions struct {
@@ -58,10 +60,11 @@ type Functions struct {
 }
 
 type TemplateConfig struct {
-	Completion string `yaml:"completion"`
-	Functions  string `yaml:"function"`
-	Chat       string `yaml:"chat"`
-	Edit       string `yaml:"edit"`
+	Chat        string `yaml:"chat"`
+	ChatMessage string `yaml:"chat_message"`
+	Completion  string `yaml:"completion"`
+	Edit        string `yaml:"edit"`
+	Functions   string `yaml:"function"`
 }
 
 type ConfigLoader struct {
diff --git a/api/openai/chat.go b/api/openai/chat.go
index d85bf7b3..a9cbd240 100644
--- a/api/openai/chat.go
+++ b/api/openai/chat.go
@@ -43,12 +43,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		processFunctions := false
 		funcs := grammar.Functions{}
-		model, input, err := readInput(c, o.Loader, true)
+		modelFile, input, err := readInput(c, o.Loader, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
 
-		config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
+		config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
@@ -110,9 +110,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 		var predInput string
 
 		mess := []string{}
-		for _, i := range input.Messages {
+		for messageIndex, i := range input.Messages {
 			var content string
 			role := i.Role
+
 			// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
 			// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
 			if i.FunctionCall != nil && i.Role == "assistant" {
@@ -124,31 +125,55 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 			}
 			r := config.Roles[role]
 			contentExists := i.Content != nil && *i.Content != ""
-			if r != "" {
-				if contentExists {
-					content = fmt.Sprint(r, " ", *i.Content)
+			// First attempt to populate content via a chat message specific template
+			if config.TemplateConfig.ChatMessage != "" {
+				chatMessageData := model.ChatMessageTemplateData{
+					SystemPrompt: config.SystemPrompt,
+					Role:         r,
+					RoleName:     role,
+					Content:      *i.Content,
+					MessageIndex: messageIndex,
 				}
-				if i.FunctionCall != nil {
-					j, err := json.Marshal(i.FunctionCall)
-					if err == nil {
-						if contentExists {
-							content += "\n" + fmt.Sprint(r, " ", string(j))
-						} else {
-							content = fmt.Sprint(r, " ", string(j))
+				templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
+				if err != nil {
+					log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
+				} else {
+					if templatedChatMessage == "" {
+						log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
+						continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
+					}
+					log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
+					content = templatedChatMessage
+				}
+			}
+			// If this model doesn't have such a template, or if the template produced no content, fall back to the previous role-prefix behavior
+			if content == "" {
+				if r != "" {
+					if contentExists {
+						content = fmt.Sprint(r, " ", *i.Content)
+					}
+					if i.FunctionCall != nil {
+						j, err := json.Marshal(i.FunctionCall)
+						if err == nil {
+							if contentExists {
+								content += "\n" + fmt.Sprint(r, " ", string(j))
+							} else {
+								content = fmt.Sprint(r, " ", string(j))
+							}
 						}
 					}
-				}
-			} else {
-				if contentExists {
-					content = fmt.Sprint(*i.Content)
-				}
-				if i.FunctionCall != nil {
-					j, err := json.Marshal(i.FunctionCall)
-					if err == nil {
-						if contentExists {
-							content += "\n" + string(j)
-						} else {
-							content = string(j)
+				} else {
+					if contentExists {
+						content = fmt.Sprint(*i.Content)
+					}
+					if i.FunctionCall != nil {
+						j, err := json.Marshal(i.FunctionCall)
+						if err == nil {
+							if contentExists {
+								content += "\n" + string(j)
+							} else {
+								content = string(j)
+							}
 						}
 					}
 				}
@@ -181,10 +206,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
 		}
 
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct {
-			Input     string
-			Functions []grammar.Function
-		}{
+		templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
 			Input:     predInput,
 			Functions: funcs,
 		})
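Note (not part of the diff): taken together, the config and chat changes above mean a model YAML can now carry a default system prompt and a per-message template. A sketch of such a config, using the repo's existing YAML layout - all names, paths, and values here are illustrative assumptions:

    # hypothetical model config (sketch)
    name: llama2-chat
    parameters:
      model: llama-2-7b-chat.bin
    system_prompt: "You are a helpful assistant."
    roles:
      user: "user"
      assistant: "assistant"
    template:
      chat: llama2-chat                  # whole-prompt template (ChatPromptTemplate)
      chat_message: llama2-chat-message  # per-message template (ChatMessageTemplate)

With `chat_message` set, the message loop renders each message through that template first, and only falls back to the old role-prefix concatenation when the template is absent or produces an empty string.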
Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + // If this model doesn't have such a template, or if + if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, " ", *i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } } } - } - } else { - if contentExists { - content = fmt.Sprint(*i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) + } else { + if contentExists { + content = fmt.Sprint(*i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } } } } @@ -181,10 +206,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { - Input string - Functions []grammar.Function - }{ + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ Input: predInput, Functions: funcs, }) diff --git a/api/openai/completion.go b/api/openai/completion.go index 4b38c30c..1efe37c7 100644 --- a/api/openai/completion.go +++ b/api/openai/completion.go @@ -38,14 +38,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe } return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.Loader, true) + modelFile, input, err := readInput(c, o.Loader, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } log.Debug().Msgf("`input`: %+v", input) - config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -76,9 +76,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe predInput := config.PromptStrings[0] // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { - Input string - }{ + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ Input: predInput, }) if err == nil { @@ -124,9 +122,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe var result []Choice for k, i := range config.PromptStrings { // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { - Input string - }{ + templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ Input: i, }) if err == nil { diff --git a/api/openai/edit.go b/api/openai/edit.go index 
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 08bf6c4d..86ee554d 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -128,7 +128,7 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
 // It also loads the model
 func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
 	return func(s string) (*grpc.Client, error) {
-		log.Debug().Msgf("Loading GRPC Model", backend, *o)
+		log.Debug().Msgf("Loading GRPC Model %s: %+v", backend, *o)
 
 		var client *grpc.Client
 
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index 833c3115..bb49a7cc 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -4,43 +4,81 @@ import (
 	"bytes"
 	"context"
 	"fmt"
-	"io/ioutil"
 	"os"
 	"path/filepath"
 	"strings"
 	"sync"
 	"text/template"
 
+	grammar "github.com/go-skynet/LocalAI/pkg/grammar"
 	"github.com/go-skynet/LocalAI/pkg/grpc"
 	process "github.com/mudler/go-processmanager"
 	"github.com/rs/zerolog/log"
 )
 
+// Rather than passing an interface{} to the prompt templates:
+// these are the definitions of all possible variables LocalAI will currently populate for use in a prompt template file.
+// Please note: not all of these are populated on every endpoint - your template should either be tested for each endpoint you map it to, or be tolerant of zero values.
+type PromptTemplateData struct {
+	Input        string
+	Instruction  string
+	Functions    []grammar.Function
+	MessageIndex int
+}
+
+// TODO: Ask mudler about FunctionCall data being useful at the message level?
+type ChatMessageTemplateData struct {
+	SystemPrompt string
+	Role         string
+	RoleName     string
+	Content      string
+	MessageIndex int
+}
+
+// Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in Go?
+// Technically, order doesn't _really_ matter, but the count must stay in sync; see tests/integration/reflect_test.go
+type TemplateType int
+
+const (
+	ChatPromptTemplate TemplateType = iota
+	ChatMessageTemplate
+	CompletionPromptTemplate
+	EditPromptTemplate
+	FunctionsPromptTemplate
+
+	// The following TemplateType is **NOT** a valid value and MUST be last. It exists to make the sanity integration tests simpler!
+	IntegrationTestTemplate
+)
+
+// new idea: what if we declare a struct of these here, and use a loop to check?
+
+// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we separate directories for .bin/.yaml and .tmpl
 type ModelLoader struct {
 	ModelPath string
 	mu        sync.Mutex
 	// TODO: this needs generics
-	models           map[string]*grpc.Client
-	grpcProcesses    map[string]*process.Process
-	promptsTemplates map[string]*template.Template
+	models        map[string]*grpc.Client
+	grpcProcesses map[string]*process.Process
+	templates     map[TemplateType]map[string]*template.Template
 }
 
 func NewModelLoader(modelPath string) *ModelLoader {
-	return &ModelLoader{
-		ModelPath:        modelPath,
-		models:           make(map[string]*grpc.Client),
-		promptsTemplates: make(map[string]*template.Template),
-		grpcProcesses:    make(map[string]*process.Process),
+	nml := &ModelLoader{
+		ModelPath:     modelPath,
+		models:        make(map[string]*grpc.Client),
+		templates:     make(map[TemplateType]map[string]*template.Template),
+		grpcProcesses: make(map[string]*process.Process),
 	}
+	nml.initializeTemplateMap()
+	return nml
 }
 
 func (ml *ModelLoader) ExistsInModelPath(s string) bool {
-	_, err := os.Stat(filepath.Join(ml.ModelPath, s))
-	return err == nil
+	return existsInPath(ml.ModelPath, s)
 }
 
 func (ml *ModelLoader) ListModels() ([]string, error) {
-	files, err := ioutil.ReadDir(ml.ModelPath)
+	files, err := os.ReadDir(ml.ModelPath)
 	if err != nil {
 		return []string{}, err
 	}
@@ -58,63 +96,6 @@ func (ml *ModelLoader) ListModels() ([]string, error) {
 	return models, nil
 }
 
-func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
-	ml.mu.Lock()
-	defer ml.mu.Unlock()
-
-	m, ok := ml.promptsTemplates[modelName]
-	if !ok {
-		modelFile := filepath.Join(ml.ModelPath, modelName)
-		if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
-			return "", err
-		}
-
-		t, exists := ml.promptsTemplates[modelName]
-		if exists {
-			m = t
-		}
-	}
-	if m == nil {
-		return "", fmt.Errorf("failed loading any template")
-	}
-
-	var buf bytes.Buffer
-
-	if err := m.Execute(&buf, in); err != nil {
-		return "", err
-	}
-	return buf.String(), nil
-}
-
-func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
-	// Check if the template was already loaded
-	if _, ok := ml.promptsTemplates[modelName]; ok {
-		return nil
-	}
-
-	// Check if the model path exists
-	// skip any error here - we run anyway if a template does not exist
-	modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName)
-
-	if !ml.ExistsInModelPath(modelTemplateFile) {
-		return nil
-	}
-
-	dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
-	if err != nil {
-		return err
-	}
-
-	// Parse the template
-	tmpl, err := template.New("prompt").Parse(string(dat))
-	if err != nil {
-		return err
-	}
-	ml.promptsTemplates[modelName] = tmpl
-
-	return nil
-}
-
 func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
@@ -134,10 +115,13 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) {
 		return nil, err
 	}
 
-	// If there is a prompt template, load it
-	if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
-		return nil, err
-	}
+	// TODO: Add a helper method to iterate all prompt templates associated with a config if and only if it's YAML?
+	// Minor perf loss here until this is fixed, but we initialize on first request
+
+	// // If there is a prompt template, load it
+	// if err := ml.loadTemplateIfExists(modelName); err != nil {
+	// 	return nil, err
+	// }
 
 	ml.models[modelName] = model
 	return model, nil
@@ -148,9 +132,9 @@ func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
 		log.Debug().Msgf("Model already loaded in memory: %s", s)
 
 		if !m.HealthCheck(context.Background()) {
-			log.Debug().Msgf("GRPC Model not responding", s)
+			log.Debug().Msgf("GRPC Model not responding: %s", s)
 			if !ml.grpcProcesses[s].IsAlive() {
-				log.Debug().Msgf("GRPC Process is not responding", s)
+				log.Debug().Msgf("GRPC Process is not responding: %s", s)
 				// stop and delete the process, this forces to re-load the model and re-create again the service
 				ml.grpcProcesses[s].Stop()
 				delete(ml.grpcProcesses, s)
@@ -164,3 +148,81 @@ func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
 
 	return nil
 }
+
+func (ml *ModelLoader) EvaluateTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) {
+	// TODO: should this check be improved?
+	if templateType == ChatMessageTemplate {
+		return "", fmt.Errorf("invalid templateType: ChatMessage")
+	}
+	return ml.evaluateTemplate(templateType, templateName, in)
+}
+
+func (ml *ModelLoader) EvaluateTemplateForChatMessage(templateName string, messageData ChatMessageTemplateData) (string, error) {
+	return ml.evaluateTemplate(ChatMessageTemplate, templateName, messageData)
+}
+
+func existsInPath(path string, s string) bool {
+	_, err := os.Stat(filepath.Join(path, s))
+	return err == nil
+}
+
+func (ml *ModelLoader) initializeTemplateMap() {
+	// This is somewhat clunky in that it relies on the test-only end-of-valid-values slug, but it works
+	for tt := TemplateType(0); tt < IntegrationTestTemplate; tt++ {
+		ml.templates[tt] = make(map[string]*template.Template)
+	}
+}
+
+func (ml *ModelLoader) evaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	m, ok := ml.templates[templateType][templateName]
+	if !ok {
+		// return "", fmt.Errorf("template not loaded: %s", templateName)
+		loadErr := ml.loadTemplateIfExists(templateType, templateName)
+		if loadErr != nil {
+			return "", loadErr
+		}
+		m = ml.templates[templateType][templateName] // ok is not important here since we check m against nil on the next line, and we already checked above
+	}
+	if m == nil {
+		return "", fmt.Errorf("failed loading a template for %s", templateName)
+	}
+
+	var buf bytes.Buffer
+
+	if err := m.Execute(&buf, in); err != nil {
+		return "", err
+	}
+	return buf.String(), nil
+}
+
+func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateName string) error {
+	// Check if the template was already loaded
+	if _, ok := ml.templates[templateType][templateName]; ok {
+		return nil
+	}
+
+	// Check if the model path exists
+	// skip any error here - we run anyway if a template does not exist
+	modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
+
+	if !ml.ExistsInModelPath(modelTemplateFile) {
+		return nil
+	}
+
+	dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
+	if err != nil {
+		return err
+	}
+
+	// Parse the template
+	tmpl, err := template.New("prompt").Parse(string(dat))
+	if err != nil {
+		return err
+	}
+	ml.templates[templateType][templateName] = tmpl
+
+	return nil
+}
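Note (not part of the diff): a minimal sketch of how the two new entry points are intended to be called. The model directory and template names below are hypothetical, and the corresponding `.tmpl` files are assumed to exist under the model path:

    package main

    import (
    	"fmt"

    	model "github.com/go-skynet/LocalAI/pkg/model"
    )

    func main() {
    	ml := model.NewModelLoader("./models") // hypothetical model directory

    	// Prompt-level templates (chat, completion, edit, functions) all consume PromptTemplateData.
    	prompt, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, "mymodel", model.PromptTemplateData{
    		Input: "The capital of France is",
    	})
    	if err != nil {
    		fmt.Println("prompt template:", err)
    		return
    	}
    	fmt.Println(prompt)

    	// Message-level templates consume ChatMessageTemplateData instead; the separate
    	// wrapper exists because EvaluateTemplateForPrompt rejects ChatMessageTemplate.
    	msg, err := ml.EvaluateTemplateForChatMessage("llama2-chat-message", model.ChatMessageTemplateData{
    		SystemPrompt: "You are a helpful assistant.",
    		RoleName:     "user",
    		Content:      "Hello!",
    		MessageIndex: 0,
    	})
    	if err != nil {
    		fmt.Println("chat message template:", err)
    		return
    	}
    	fmt.Println(msg)
    }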
diff --git a/prompt-templates/llama2-chat-message.tmpl b/prompt-templates/llama2-chat-message.tmpl
new file mode 100644
index 00000000..e99efe82
--- /dev/null
+++ b/prompt-templates/llama2-chat-message.tmpl
@@ -0,0 +1,7 @@
+{{if eq .RoleName "assistant"}}{{.Content}}{{else}}
+[INST]
+{{if .SystemPrompt}}{{.SystemPrompt}}{{else if eq .RoleName "system"}}<<SYS>>{{.Content}}<</SYS>>
+
+{{else if .Content}}{{.Content}}{{end}}
+[/INST]
+{{end}}
\ No newline at end of file
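Note (not part of the diff): fed a plain user message (RoleName "user", Content "Hello!", no SystemPrompt), the template above renders, modulo leading and trailing newlines:

    [INST]
    Hello!
    [/INST]

Assistant messages pass through as bare `{{.Content}}`, and when `system_prompt` is set in the model config it takes the place of the message content between the [INST] tags, since the three branches of the inner conditional are mutually exclusive.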
diff --git a/tests/integration/reflect_test.go b/tests/integration/reflect_test.go
new file mode 100644
index 00000000..c0fe7096
--- /dev/null
+++ b/tests/integration/reflect_test.go
@@ -0,0 +1,26 @@
+package integration_test
+
+import (
+	"reflect"
+
+	config "github.com/go-skynet/LocalAI/api/config"
+	model "github.com/go-skynet/LocalAI/pkg/model"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("Integration tests involving reflection in lieu of code generation", func() {
+	Context("config.TemplateConfig and model.TemplateType must stay in sync", func() {
+
+		ttc := reflect.TypeOf(config.TemplateConfig{})
+
+		It("TemplateConfig and TemplateType should have the same number of valid values", func() {
+			// IntegrationTestTemplate sits one past the last valid TemplateType, so its integer
+			// value equals the count of valid values; convert it to int so that gomega's
+			// strictly-typed Equal matcher can compare it against reflect's NumField().
+			Expect(int(model.IntegrationTestTemplate)).To(Equal(ttc.NumField()))
+		})
+
+	})
+})
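Note (not part of the diff): ginkgo v2 specs only run when the package has a `go test` entry point that calls `RunSpecs`. If `tests/integration` does not already ship one, the usual boilerplate would be required alongside this file (file and suite names hypothetical):

    package integration_test

    import (
    	"testing"

    	. "github.com/onsi/ginkgo/v2"
    	. "github.com/onsi/gomega"
    )

    func TestLocalAIIntegration(t *testing.T) {
    	RegisterFailHandler(Fail)
    	RunSpecs(t, "LocalAI integration tests")
    }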