From 1c4fbaae20069d3d96c311be0e02e4789b260ced Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 20 Apr 2023 19:33:36 +0200
Subject: [PATCH] Add support for cerebras (#45)

Signed-off-by: mudler
---
 Makefile            | 39 +++++++++++++++++++++++++------
 README.md           |  5 +---
 api/api.go          | 38 ++++++++++++++++++++++++++----
 go.mod              |  1 +
 go.sum              |  4 ++++
 pkg/model/loader.go | 57 +++++++++++++++++++++++++++++++++++++++++----
 6 files changed, 124 insertions(+), 20 deletions(-)

diff --git a/Makefile b/Makefile
index b6727935..0495a9ca 100644
--- a/Makefile
+++ b/Makefile
@@ -17,11 +17,12 @@ all: help

 ## Build:
 build: prepare ## Build the project
-	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
+	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp $(GOCMD) build -o $(BINARY_NAME) ./

 buildgeneric: prepare-generic ## Build the project
-	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j $(GOCMD) build -o $(BINARY_NAME) ./
+	C_INCLUDE_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp LIBRARY_PATH=$(shell pwd)/go-llama.cpp:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2.cpp $(GOCMD) build -o $(BINARY_NAME) ./

+## GPT4ALL-J
 go-gpt4all-j:
 	git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
 # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@@ -30,6 +31,9 @@ go-gpt4all-j:
 	@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
 	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
 	@find ./go-gpt4all-j -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
+	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +

 go-gpt4all-j/libgptj.a: go-gpt4all-j
 	$(MAKE) -C go-gpt4all-j libgptj.a
@@ -37,6 +41,25 @@ go-gpt4all-j/libgptj.a: go-gpt4all-j
 go-gpt4all-j/libgptj.a-generic: go-gpt4all-j
 	$(MAKE) -C go-gpt4all-j generic-libgptj.a

+# CEREBRAS GPT
+go-gpt2.cpp:
+	git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2.cpp
+# This is hackish, but needed as go-llama, go-gpt4all-j and go-gpt2 each ship their own version of ggml..
+	@find ./go-gpt2.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt2_/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gpt2_replace/g' {} +
+	@find ./go-gpt2.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gpt2_replace/g' {} +
+
+go-gpt2.cpp/libgpt2.a: go-gpt2.cpp
+	$(MAKE) -C go-gpt2.cpp libgpt2.a
+
+go-gpt2.cpp/libgpt2.a-generic: go-gpt2.cpp
+	$(MAKE) -C go-gpt2.cpp generic-libgpt2.a
+
 go-llama:
 	git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
 	$(MAKE) -C go-llama libbinding.a
@@ -45,17 +68,19 @@ go-llama-generic:
 	git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
 	$(MAKE) -C go-llama generic-libbinding.a

-prepare: go-llama go-gpt4all-j/libgptj.a
+replace:
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
+	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2.cpp
+
+prepare: go-llama go-gpt4all-j/libgptj.a go-gpt2.cpp/libgpt2.a replace
+
+prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic go-gpt2.cpp/libgpt2.a-generic replace

-prepare-generic: go-llama-generic go-gpt4all-j/libgptj.a-generic
-	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
-	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
-
 clean: ## Remove build related file
 	rm -fr ./go-llama
 	rm -rf ./go-gpt4all-j
+	rm -rf ./go-gpt2.cpp
 	rm -rf $(BINARY_NAME)

 ## Run:
diff --git a/README.md b/README.md
index 6cc217b3..c74b0d50 100644
--- a/README.md
+++ b/README.md
@@ -12,13 +12,12 @@ LocalAI is a straightforward, drop-in replacement API compatible with OpenAI for
 - OpenAI compatible API
 - Supports multiple-models
 - Once loaded the first time, it keep models loaded in memory for faster inference
-- Provides a simple command line interface that allows text generation directly from the terminal
 - Support for prompt templates
 - Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).

 ## Model compatibility

-It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) and also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all).
+It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp), and also supports [GPT4ALL-J](https://github.com/nomic-ai/gpt4all) and [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml).

 Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.

@@ -97,8 +96,6 @@ And you'll see:
 └───────────────────────────────────────────────────┘
 ```

-Note: Models have to end up with `.bin` so can be listed by the `/models` endpoint. 
-
 You can control the API server options with command line arguments:

 ```
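The README change above is the user-visible half of this patch: a ggml-format Cerebras-GPT model dropped into the models directory is served through the same OpenAI-compatible API as llama.cpp and GPT4ALL-J models. As a rough illustration (not part of the patch; the model file name, server address, and the `/v1/completions` route are assumptions), a minimal Go client could look like this:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Hypothetical model file name; any ggml Cerebras-GPT model
	// placed in the models directory should work the same way.
	payload, err := json.Marshal(map[string]interface{}{
		"model":       "cerebras-gpt-2.7b.bin",
		"prompt":      "Once upon a time",
		"temperature": 0.7,
	})
	if err != nil {
		panic(err)
	}

	// Assumes the LocalAI server is listening on localhost:8080.
	resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}
```

Which backend actually serves the request is decided at load time by the api.go changes below.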
diff --git a/api/api.go b/api/api.go
index 8c6fc8df..3ec1d813 100644
--- a/api/api.go
+++ b/api/api.go
@@ -8,6 +8,7 @@ import (
 	"sync"

 	model "github.com/go-skynet/LocalAI/pkg/model"
+	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
 	"github.com/gofiber/fiber/v2"
@@ -73,6 +74,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		var err error
 		var model *llama.LLama
 		var gptModel *gptj.GPTJ
+		var gpt2Model *gpt2.GPT2

 		input := new(OpenAIRequest)
 		// Get input data from the request body
@@ -97,7 +99,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		}

 		// Try to load the model with both
-		var llamaerr error
+		var llamaerr, gpt2err, gptjerr error
 		llamaOpts := []llama.ModelOption{}
 		if ctx != 0 {
 			llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@@ -106,11 +108,15 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 			llamaOpts = append(llamaOpts, llama.EnableF16Memory)
 		}

+		// TODO: this is ugly; we should identify the model type properly. Still, it is a good stab for a first implementation..
 		model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
 		if llamaerr != nil {
-			gptModel, err = loader.LoadGPTJModel(modelFile)
-			if err != nil {
-				return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
+			gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
+			if gptjerr != nil {
+				gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
+				if gpt2err != nil {
+					return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch all of the errors
+				}
 			}
 		}

@@ -176,6 +182,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		var predFunc func() (string, error)
 		switch {
+		case gpt2Model != nil:
+			predFunc = func() (string, error) {
+				// Generate the prediction using the language model
+				predictOptions := []gpt2.PredictOption{
+					gpt2.SetTemperature(temperature),
+					gpt2.SetTopP(topP),
+					gpt2.SetTopK(topK),
+					gpt2.SetTokens(tokens),
+					gpt2.SetThreads(threads),
+				}
+
+				if input.Batch != 0 {
+					predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
+				}
+
+				if input.Seed != 0 {
+					predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
+				}
+
+				return gpt2Model.Predict(
+					predInput,
+					predictOptions...,
+				)
+			}
 		case gptModel != nil:
 			predFunc = func() (string, error) {
 				// Generate the prediction using the language model
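The `gpt2.PredictOption` values assembled above come straight from the go-gpt2.cpp binding, so the new backend can also be exercised directly, without the API server in front of it. A minimal sketch, assuming a ggml Cerebras-GPT file at an illustrative path and with option values chosen purely for demonstration:

```go
package main

import (
	"fmt"

	gpt2 "github.com/go-skynet/go-gpt2.cpp"
)

func main() {
	// Illustrative path; point this at a real ggml Cerebras-GPT model file.
	model, err := gpt2.New("models/cerebras-gpt-2.7b.bin")
	if err != nil {
		panic(err)
	}

	// The same options openAIEndpoint assembles from the request body.
	out, err := model.Predict(
		"The Cerebras wafer-scale engine is",
		gpt2.SetTemperature(0.7),
		gpt2.SetTopP(0.9),
		gpt2.SetTopK(40),
		gpt2.SetTokens(128),
		gpt2.SetThreads(4),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
```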
diff --git a/go.mod b/go.mod
index f7375d5a..77076157 100644
--- a/go.mod
+++ b/go.mod
@@ -13,6 +13,7 @@ require (
 	github.com/andybalholm/brotli v1.0.4 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
+	github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d // indirect
 	github.com/google/uuid v1.3.0 // indirect
 	github.com/klauspost/compress v1.15.9 // indirect
 	github.com/mattn/go-colorable v0.1.13 // indirect
diff --git a/go.sum b/go.sum
index 41a2071b..b98a7d8f 100644
--- a/go.sum
+++ b/go.sum
@@ -4,6 +4,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
 github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
+github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d h1:8crcrVuvpRzf6wejPtIFYGmMrSTfW94CYPJZIssT8zo=
+github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
+github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d h1:Jabxk0NI5CLbY7PVODkRp1AQbEovS9gM6jGAOwyy5FI=
+github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
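The loader changes below follow the pattern already used for llama and GPT-J: one cache map per backend behind a single mutex, so each model file is parsed once per process and served from memory afterwards. A sketch of how a caller drives it, with an illustrative model name (only `NewModelLoader` and `LoadGPT2Model` from the diff below are real):

```go
package main

import (
	"fmt"

	model "github.com/go-skynet/LocalAI/pkg/model"
)

func main() {
	loader := model.NewModelLoader("./models")

	// The first call parses the ggml file and caches the handle; later
	// calls with the same name return the cached *gpt2.GPT2 instance.
	m, err := loader.LoadGPT2Model("cerebras-gpt-2.7b.bin") // illustrative name
	if err != nil {
		panic(err)
	}

	out, err := m.Predict("Hello")
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
```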
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index 09f57db4..1db1713b 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -12,20 +12,24 @@ import (

 	"github.com/rs/zerolog/log"

+	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
 )

 type ModelLoader struct {
-	modelPath string
-	mu        sync.Mutex
-	models    map[string]*llama.LLama
-	gptmodels map[string]*gptj.GPTJ
+	modelPath string
+	mu        sync.Mutex
+
+	models     map[string]*llama.LLama
+	gptmodels  map[string]*gptj.GPTJ
+	gpt2models map[string]*gpt2.GPT2
+
 	promptsTemplates map[string]*template.Template
 }

 func NewModelLoader(modelPath string) *ModelLoader {
-	return &ModelLoader{modelPath: modelPath, gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
+	return &ModelLoader{modelPath: modelPath, gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
 }

 func (ml *ModelLoader) ExistsInModelPath(s string) bool {
@@ -98,6 +102,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
 	return nil
 }

+func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	// Check that the model file exists in the model path at all
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
+	}
+
+	// Check if we already have a loaded model
+	if m, ok := ml.gpt2models[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.modelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model, err := gpt2.New(modelFile)
+	if err != nil {
+		return nil, err
+	}
+
+	// If there is a prompt template, load it
+	if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
+		return nil, err
+	}
+
+	ml.gpt2models[modelName] = model
+	return model, nil
+}
+
 func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
@@ -112,6 +148,13 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 		return m, nil
 	}

+	// TODO: This needs refactoring, it's really bad to have it in here
+	// Check if we have a GPT2 model loaded instead; if so, return an error so the API retries with GPT2
+	if _, ok := ml.gpt2models[modelName]; ok {
+		log.Debug().Msgf("Model is GPT2: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPT2 one")
+	}
+
 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
 	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
@@ -152,6 +195,10 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
 		log.Debug().Msgf("Model is GPTJ: %s", modelName)
 		return nil, fmt.Errorf("this model is a GPTJ one")
 	}
+	if _, ok := ml.gpt2models[modelName]; ok {
+		log.Debug().Msgf("Model is GPT2: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPT2 one")
+	}

 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
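Taken together, the loader guards above and the api.go changes earlier implement a simple try-each-backend strategy: llama first, then GPT-J, then GPT-2, with the per-backend cache checks ensuring a model already identified by one backend is never re-parsed by another. Condensed into one function for clarity (a restatement of the patch's control flow, not code from the patch):

```go
package main

import (
	"fmt"

	model "github.com/go-skynet/LocalAI/pkg/model"
	llama "github.com/go-skynet/go-llama.cpp"
)

// loadAny mirrors the cascade in api.go: try llama, then GPT-J, then
// GPT-2, and report every error if no backend can load the file.
func loadAny(loader *model.ModelLoader, name string, opts ...llama.ModelOption) (interface{}, error) {
	llamaModel, llamaErr := loader.LoadLLaMAModel(name, opts...)
	if llamaErr == nil {
		return llamaModel, nil
	}
	gptjModel, gptjErr := loader.LoadGPTJModel(name)
	if gptjErr == nil {
		return gptjModel, nil
	}
	gpt2Model, gpt2Err := loader.LoadGPT2Model(name)
	if gpt2Err == nil {
		return gpt2Model, nil
	}
	return nil, fmt.Errorf("llama: %s gptj: %s gpt2: %s", llamaErr, gptjErr, gpt2Err)
}

func main() {
	m, err := loadAny(model.NewModelLoader("./models"), "model.bin") // illustrative name
	if err != nil {
		panic(err)
	}
	fmt.Printf("loaded %T\n", m)
}
```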