diff --git a/api/api_test.go b/api/api_test.go index a14a7dcb..199ef14a 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -79,7 +79,7 @@ var _ = Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: llama: model does not exist")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:")) }) }) diff --git a/api/config.go b/api/config.go index ea4b335b..c57debdc 100644 --- a/api/config.go +++ b/api/config.go @@ -21,6 +21,7 @@ type Config struct { Threads int `yaml:"threads"` Debug bool `yaml:"debug"` Roles map[string]string `yaml:"roles"` + Backend string `yaml:"backend"` TemplateConfig TemplateConfig `yaml:"template"` } diff --git a/api/openai.go b/api/openai.go index bbb26ea7..63b2b32d 100644 --- a/api/openai.go +++ b/api/openai.go @@ -119,7 +119,9 @@ func updateConfig(config *Config, input *OpenAIRequest) { switch stop := input.Stop.(type) { case string: - config.StopWords = append(config.StopWords, stop) + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } case []string: config.StopWords = append(config.StopWords, stop...) diff --git a/api/prediction.go b/api/prediction.go index 4f01abbb..4d2f77ca 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -10,22 +10,86 @@ import ( gpt2 "github.com/go-skynet/go-gpt2.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp" llama "github.com/go-skynet/go-llama.cpp" + "github.com/hashicorp/go-multierror" ) // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 var mutexMap sync.Mutex var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) +var loadedModels map[string]interface{} = map[string]interface{}{} +var muModels sync.Mutex + +func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) { + switch strings.ToLower(backendString) { + case "llama": + return loader.LoadLLaMAModel(modelFile, llamaOpts...) + case "stablelm": + return loader.LoadStableLMModel(modelFile) + case "gpt2": + return loader.LoadGPT2Model(modelFile) + case "gptj": + return loader.LoadGPTJModel(modelFile) + default: + return nil, fmt.Errorf("backend unsupported: %s", backendString) + } +} + +func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) { + updateModels := func(model interface{}) { + muModels.Lock() + defer muModels.Unlock() + loadedModels[modelFile] = model + } + + muModels.Lock() + m, exists := loadedModels[modelFile] + if exists { + muModels.Unlock() + return m, nil + } + muModels.Unlock() + + model, modelerr := loader.LoadLLaMAModel(modelFile, llamaOpts...) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + + model, modelerr = loader.LoadGPTJModel(modelFile) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + + model, modelerr = loader.LoadGPT2Model(modelFile) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + + model, modelerr = loader.LoadStableLMModel(modelFile) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + + return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error()) +} + func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) { - var model *llama.LLama - var gptModel *gptj.GPTJ - var gpt2Model *gpt2.GPT2 - var stableLMModel *gpt2.StableLM supportStreams := false modelFile := c.Model // Try to load the model - var llamaerr, gpt2err, gptjerr, stableerr error llamaOpts := []llama.ModelOption{} if c.ContextSize != 0 { llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) @@ -34,25 +98,21 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback llamaOpts = append(llamaOpts, llama.EnableF16Memory) } - // TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation.. - model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...) - if llamaerr != nil { - gptModel, gptjerr = loader.LoadGPTJModel(modelFile) - if gptjerr != nil { - gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile) - if gpt2err != nil { - stableLMModel, stableerr = loader.LoadStableLMModel(modelFile) - if stableerr != nil { - return nil, fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors - } - } - } + var inferenceModel interface{} + var err error + if c.Backend == "" { + inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts) + } else { + inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts) + } + if err != nil { + return nil, err } var fn func() (string, error) - switch { - case stableLMModel != nil: + switch model := inferenceModel.(type) { + case *gpt2.StableLM: fn = func() (string, error) { // Generate the prediction using the language model predictOptions := []gpt2.PredictOption{ @@ -71,12 +131,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed)) } - return stableLMModel.Predict( + return model.Predict( s, predictOptions..., ) } - case gpt2Model != nil: + case *gpt2.GPT2: fn = func() (string, error) { // Generate the prediction using the language model predictOptions := []gpt2.PredictOption{ @@ -95,12 +155,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed)) } - return gpt2Model.Predict( + return model.Predict( s, predictOptions..., ) } - case gptModel != nil: + case *gptj.GPTJ: fn = func() (string, error) { // Generate the prediction using the language model predictOptions := []gptj.PredictOption{ @@ -119,12 +179,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback predictOptions = append(predictOptions, gptj.SetSeed(c.Seed)) } - return gptModel.Predict( + return model.Predict( s, predictOptions..., ) } - case model != nil: + case *llama.LLama: supportStreams = true fn = func() (string, error) { diff --git a/go.mod b/go.mod index 50b9797e..2f017906 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c github.com/go-skynet/go-llama.cpp v0.0.0-20230430075552-377fd245eae2 github.com/gofiber/fiber/v2 v2.44.0 + github.com/hashicorp/go-multierror v1.1.1 github.com/jaypipes/ghw v0.10.0 github.com/onsi/ginkgo/v2 v2.9.2 github.com/onsi/gomega v1.27.6 @@ -29,6 +30,7 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/google/uuid v1.3.0 // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect github.com/jaypipes/pcidb v1.0.0 // indirect github.com/klauspost/compress v1.16.3 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index ceb3b667..528e445c 100644 --- a/go.sum +++ b/go.sum @@ -23,14 +23,6 @@ github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708 h1:cfOi4TWvQ github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230501160437-8417608f0e94 h1:klkEyXTg7bpchNNpIQH1f2wX/C17lFLti8isCCC3mYo= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230501160437-8417608f0e94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230428071219-3d084e4299e9 h1:N/0SBefkMFao6GiGhIF7+5EdYOMHn4KnCG2AFcIXPt0= -github.com/go-skynet/go-llama.cpp v0.0.0-20230428071219-3d084e4299e9/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw= -github.com/go-skynet/go-llama.cpp v0.0.0-20230429125915-9bf702fe56b9 h1:20/tdOA4+b7Y7lCob+q2sczfOSz0pp+14L32adYJ+uQ= -github.com/go-skynet/go-llama.cpp v0.0.0-20230429125915-9bf702fe56b9/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw= -github.com/go-skynet/go-llama.cpp v0.0.0-20230429225431-361b9f87de6d h1:7KDq1Uylm1mXphQ+M2qztekXAvODtXvJDHrXQguRw9k= -github.com/go-skynet/go-llama.cpp v0.0.0-20230429225431-361b9f87de6d/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw= github.com/go-skynet/go-llama.cpp v0.0.0-20230430075552-377fd245eae2 h1:CYQRCbOfYtC77OxweAyrdxSVwoLIM/EdZ6Ij+xBzta8= github.com/go-skynet/go-llama.cpp v0.0.0-20230430075552-377fd245eae2/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= @@ -45,6 +37,10 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jaypipes/ghw v0.10.0 h1:UHu9UX08Py315iPojADFPOkmjTsNzHj4g4adsNKKteY= github.com/jaypipes/ghw v0.10.0/go.mod h1:jeJGbkRB2lL3/gxYzNYzEDETV1ZJ56OKr+CSeSEym+g= @@ -88,8 +84,6 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.9.0 h1:NoiO++IISxxJ1pRc0n7uZvMGMake0G+FJ1XPwXtprsA= -github.com/sashabaranov/go-openai v1.9.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.9.1 h1:3N52HkJKo9Zlo/oe1AVv5ZkCOny0ra58/ACvAxkN3MM= github.com/sashabaranov/go-openai v1.9.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= @@ -103,16 +97,10 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= -github.com/urfave/cli/v2 v2.25.1 h1:zw8dSP7ghX0Gmm8vugrs6q9Ku0wzweqPyshy+syu9Gw= -github.com/urfave/cli/v2 v2.25.1/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= -github.com/urfave/cli/v2 v2.25.2 h1:rgeK7wmjwH+d3DqXDDSV20GZAvNzmzu/VEsg1om3Qwg= -github.com/urfave/cli/v2 v2.25.2/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/urfave/cli/v2 v2.25.3 h1:VJkt6wvEBOoSjPFQvOkv6iWIrsJyCrKGtCtxXWwmGeY= github.com/urfave/cli/v2 v2.25.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.46.0 h1:6ZRhrFg8zBXTRYY6vdzbFhqsBd7FVv123pV2m9V87U4= -github.com/valyala/fasthttp v1.46.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/valyala/fasthttp v1.47.0 h1:y7moDoxYzMooFpT5aHgNgVOQDrS3qlkfiP9mDtGGK9c= github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 6b1539c5..34826d1a 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -168,13 +168,6 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) { return m, nil } - // TODO: This needs refactoring, it's really bad to have it in here - // Check if we have a GPTStable model loaded instead - if we do we return an error so the API tries with StableLM - if _, ok := ml.gptstablelmmodels[modelName]; ok { - log.Debug().Msgf("Model is GPTStableLM: %s", modelName) - return nil, fmt.Errorf("this model is a GPTStableLM one") - } - // Load the model and keep it in memory for later use modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile) @@ -207,17 +200,6 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) { return m, nil } - // TODO: This needs refactoring, it's really bad to have it in here - // Check if we have a GPT2 model loaded instead - if we do we return an error so the API tries with GPT2 - if _, ok := ml.gpt2models[modelName]; ok { - log.Debug().Msgf("Model is GPT2: %s", modelName) - return nil, fmt.Errorf("this model is a GPT2 one") - } - if _, ok := ml.gptstablelmmodels[modelName]; ok { - log.Debug().Msgf("Model is GPTStableLM: %s", modelName) - return nil, fmt.Errorf("this model is a GPTStableLM one") - } - // Load the model and keep it in memory for later use modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile) @@ -252,21 +234,6 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio return m, nil } - // TODO: This needs refactoring, it's really bad to have it in here - // Check if we have a GPTJ model loaded instead - if we do we return an error so the API tries with GPTJ - if _, ok := ml.gptmodels[modelName]; ok { - log.Debug().Msgf("Model is GPTJ: %s", modelName) - return nil, fmt.Errorf("this model is a GPTJ one") - } - if _, ok := ml.gpt2models[modelName]; ok { - log.Debug().Msgf("Model is GPT2: %s", modelName) - return nil, fmt.Errorf("this model is a GPT2 one") - } - if _, ok := ml.gptstablelmmodels[modelName]; ok { - log.Debug().Msgf("Model is GPTStableLM: %s", modelName) - return nil, fmt.Errorf("this model is a GPTStableLM one") - } - // Load the model and keep it in memory for later use modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)