fix: set default RoPE frequency base (1000) and scale (1) if not specified

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-07-29 00:04:25 +02:00
parent fa4de05c14
commit f085baa77d
2 changed files with 20 additions and 12 deletions

View File

@ -30,10 +30,10 @@ import (
) )
type modelApplyRequest struct { type modelApplyRequest struct {
ID string `json:"id"` ID string `json:"id"`
URL string `json:"url"` URL string `json:"url"`
Name string `json:"name"` Name string `json:"name"`
Overrides map[string]string `json:"overrides"` Overrides map[string]interface{} `json:"overrides"`
} }
func getModelStatus(url string) (response map[string]interface{}) { func getModelStatus(url string) (response map[string]interface{}) {
@ -243,7 +243,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert", Name: "bert",
Overrides: map[string]string{ Overrides: map[string]interface{}{
"backend": "llama", "backend": "llama",
}, },
}) })
@ -269,7 +269,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert", Name: "bert",
Overrides: map[string]string{}, Overrides: map[string]interface{}{},
}) })
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@ -297,7 +297,7 @@ var _ = Describe("API test", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/openllama_3b.yaml", URL: "github:go-skynet/model-gallery/openllama_3b.yaml",
Name: "openllama_3b", Name: "openllama_3b",
Overrides: map[string]string{"backend": "llama"}, Overrides: map[string]interface{}{"backend": "llama", "mmap": true, "f16": true},
}) })
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
@ -366,9 +366,8 @@ var _ = Describe("API test", func() {
} }
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "github:go-skynet/model-gallery/gpt4all-j.yaml", URL: "github:go-skynet/model-gallery/gpt4all-j.yaml",
Name: "gpt4all-j", Name: "gpt4all-j",
Overrides: map[string]string{},
}) })
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))

View File

@ -58,6 +58,15 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
} }
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
ropeFreqBase := float32(1000)
ropeFreqScale := float32(1)
if opts.RopeFreqBase != 0 {
ropeFreqBase = opts.RopeFreqBase
}
if opts.RopeFreqScale != 0 {
ropeFreqScale = opts.RopeFreqScale
}
predictOptions := []llama.PredictOption{ predictOptions := []llama.PredictOption{
llama.SetTemperature(opts.Temperature), llama.SetTemperature(opts.Temperature),
llama.SetTopP(opts.TopP), llama.SetTopP(opts.TopP),
@ -65,8 +74,8 @@ func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
llama.SetTokens(int(opts.Tokens)), llama.SetTokens(int(opts.Tokens)),
llama.SetThreads(int(opts.Threads)), llama.SetThreads(int(opts.Threads)),
llama.WithGrammar(opts.Grammar), llama.WithGrammar(opts.Grammar),
llama.SetRopeFreqBase(opts.RopeFreqBase), llama.SetRopeFreqBase(ropeFreqBase),
llama.SetRopeFreqScale(opts.RopeFreqScale), llama.SetRopeFreqScale(ropeFreqScale),
llama.SetNegativePromptScale(opts.NegativePromptScale), llama.SetNegativePromptScale(opts.NegativePromptScale),
llama.SetNegativePrompt(opts.NegativePrompt), llama.SetNegativePrompt(opts.NegativePrompt),
} }