mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
tests: add template tests (#2063)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
502c1eedaa
commit
f9c75d4878
105
pkg/model/loader_test.go
Normal file
105
pkg/model/loader_test.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
. "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
const chatML = `<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "tool"}}tool{{else if eq .RoleName "user"}}user{{end}}
|
||||||
|
{{- if .FunctionCall }}
|
||||||
|
<tool_call>
|
||||||
|
{{- else if eq .RoleName "tool" }}
|
||||||
|
<tool_response>
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Content}}
|
||||||
|
{{.Content }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .FunctionCall}}
|
||||||
|
{{toJson .FunctionCall}}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .FunctionCall }}
|
||||||
|
</tool_call>
|
||||||
|
{{- else if eq .RoleName "tool" }}
|
||||||
|
</tool_response>
|
||||||
|
{{- end }}
|
||||||
|
<|im_end|>`
|
||||||
|
|
||||||
|
var testMatch map[string]map[string]interface{} = map[string]map[string]interface{}{
|
||||||
|
"user": {
|
||||||
|
"template": chatML,
|
||||||
|
"expected": "<|im_start|>user\nA long time ago in a galaxy far, far away...\n<|im_end|>",
|
||||||
|
"data": model.ChatMessageTemplateData{
|
||||||
|
SystemPrompt: "",
|
||||||
|
Role: "user",
|
||||||
|
RoleName: "user",
|
||||||
|
Content: "A long time ago in a galaxy far, far away...",
|
||||||
|
FunctionCall: nil,
|
||||||
|
FunctionName: "",
|
||||||
|
LastMessage: false,
|
||||||
|
Function: false,
|
||||||
|
MessageIndex: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"assistant": {
|
||||||
|
"template": chatML,
|
||||||
|
"expected": "<|im_start|>assistant\nA long time ago in a galaxy far, far away...\n<|im_end|>",
|
||||||
|
"data": model.ChatMessageTemplateData{
|
||||||
|
SystemPrompt: "",
|
||||||
|
Role: "assistant",
|
||||||
|
RoleName: "assistant",
|
||||||
|
Content: "A long time ago in a galaxy far, far away...",
|
||||||
|
FunctionCall: nil,
|
||||||
|
FunctionName: "",
|
||||||
|
LastMessage: false,
|
||||||
|
Function: false,
|
||||||
|
MessageIndex: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"function_call": {
|
||||||
|
"template": chatML,
|
||||||
|
"expected": "<|im_start|>assistant\n<tool_call>\n{\"function\":\"test\"}\n</tool_call>\n<|im_end|>",
|
||||||
|
"data": model.ChatMessageTemplateData{
|
||||||
|
SystemPrompt: "",
|
||||||
|
Role: "assistant",
|
||||||
|
RoleName: "assistant",
|
||||||
|
Content: "",
|
||||||
|
FunctionCall: map[string]string{"function": "test"},
|
||||||
|
FunctionName: "",
|
||||||
|
LastMessage: false,
|
||||||
|
Function: false,
|
||||||
|
MessageIndex: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"function_response": {
|
||||||
|
"template": chatML,
|
||||||
|
"expected": "<|im_start|>tool\n<tool_response>\nResponse from tool\n</tool_response>\n<|im_end|>",
|
||||||
|
"data": model.ChatMessageTemplateData{
|
||||||
|
SystemPrompt: "",
|
||||||
|
Role: "tool",
|
||||||
|
RoleName: "tool",
|
||||||
|
Content: "Response from tool",
|
||||||
|
FunctionCall: nil,
|
||||||
|
FunctionName: "",
|
||||||
|
LastMessage: false,
|
||||||
|
Function: false,
|
||||||
|
MessageIndex: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("Templates", func() {
|
||||||
|
Context("chat message", func() {
|
||||||
|
modelLoader := NewModelLoader("")
|
||||||
|
for key := range testMatch {
|
||||||
|
foo := testMatch[key]
|
||||||
|
It("renders correctly "+key, func() {
|
||||||
|
templated, err := modelLoader.EvaluateTemplateForChatMessage(foo["template"].(string), foo["data"].(model.ChatMessageTemplateData))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(templated).To(Equal(foo["expected"]), templated)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
13
pkg/model/model_suite_test.go
Normal file
13
pkg/model/model_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModel(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "LocalAI model test")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user