From f2f1d7fe72c8205f3740c41de53b4b868f5d72cf Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH] feat: use gRPC for transformers Signed-off-by: Ettore Di Giacinto --- Makefile | 45 ++++-- api/prediction.go | 194 +------------------------ cmd/grpc/dolly/main.go | 23 +++ cmd/grpc/gpt2/main.go | 23 +++ cmd/grpc/gptj/main.go | 23 +++ cmd/grpc/gptneox/main.go | 23 +++ cmd/grpc/mpt/main.go | 23 +++ cmd/grpc/replit/main.go | 23 +++ cmd/grpc/starcoder/main.go | 23 +++ pkg/grpc/llm/ggml/starcoder.go | 0 pkg/grpc/llm/transformers/dolly.go | 42 ++++++ pkg/grpc/llm/transformers/gpt2.go | 42 ++++++ pkg/grpc/llm/transformers/gptj.go | 42 ++++++ pkg/grpc/llm/transformers/gptneox.go | 42 ++++++ pkg/grpc/llm/transformers/mpt.go | 42 ++++++ pkg/grpc/llm/transformers/predict.go | 26 ++++ pkg/grpc/llm/transformers/replit.go | 42 ++++++ pkg/grpc/llm/transformers/starcoder.go | 42 ++++++ pkg/model/initializers.go | 56 +------ 19 files changed, 518 insertions(+), 258 deletions(-) create mode 100644 cmd/grpc/dolly/main.go create mode 100644 cmd/grpc/gpt2/main.go create mode 100644 cmd/grpc/gptj/main.go create mode 100644 cmd/grpc/gptneox/main.go create mode 100644 cmd/grpc/mpt/main.go create mode 100644 cmd/grpc/replit/main.go create mode 100644 cmd/grpc/starcoder/main.go delete mode 100644 pkg/grpc/llm/ggml/starcoder.go create mode 100644 pkg/grpc/llm/transformers/dolly.go create mode 100644 pkg/grpc/llm/transformers/gpt2.go create mode 100644 pkg/grpc/llm/transformers/gptj.go create mode 100644 pkg/grpc/llm/transformers/gptneox.go create mode 100644 pkg/grpc/llm/transformers/mpt.go create mode 100644 pkg/grpc/llm/transformers/predict.go create mode 100644 pkg/grpc/llm/transformers/replit.go create mode 100644 pkg/grpc/llm/transformers/starcoder.go diff --git a/Makefile b/Makefile index df7a16e4..610cc6f7 100644 --- a/Makefile +++ b/Makefile @@ -189,21 +189,6 @@ gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all go-ggml-transformers: git clone --recurse-submodules https://github.com/go-skynet/go-ggml-transformers.cpp go-ggml-transformers cd go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1 - # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/set_numa_thread_affinity/transformers_set_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/set_numa_thread_affinity/transformers_set_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/clear_numa_thread_affinity/transformers_clear_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/clear_numa_thread_affinity/transformers_clear_numa_thread_affinity/g' {} + go-ggml-transformers/libtransformers.a: go-ggml-transformers $(MAKE) -C go-ggml-transformers libtransformers.a @@ -359,4 +344,32 @@ gpt4all-grpc: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-binding CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \ $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ -grpcs: falcon-grpc llama-grpc gpt4all-grpc \ No newline at end of file +dolly-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/ + +gpt2-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/ + +gptj-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/ + +gptneox-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/ + +mpt-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/ + +replit-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/ + +starcoder-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/ + +grpcs: falcon-grpc llama-grpc gpt4all-grpc dolly-grpc gpt2-grpc gptj-grpc gptneox-grpc mpt-grpc replit-grpc starcoder-grpc \ No newline at end of file diff --git a/api/prediction.go b/api/prediction.go index f24376ca..4a9c1c84 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -17,7 +17,6 @@ import ( "github.com/go-skynet/LocalAI/pkg/stablediffusion" "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" - transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 @@ -244,7 +243,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to opts := []model.Option{ model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), + model.WithThreads(uint32(c.Threads)), // GPT4all uses this model.WithAssetDir(o.assetsDestination), model.WithModelFile(modelFile), } @@ -279,102 +278,6 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to return response, nil } - case *transformers.GPTNeoX: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Replit: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Starcoder: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.MPT: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } case *bloomz.Bloomz: fn = func() (string, error) { // Generate the prediction using the language model @@ -395,102 +298,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to predictOptions..., ) } - case *transformers.Falcon: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.GPTJ: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Dolly: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.GPT2: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } case *grpc.Client: // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported supportStreams = true diff --git a/cmd/grpc/dolly/main.go b/cmd/grpc/dolly/main.go new file mode 100644 index 00000000..43bba92f --- /dev/null +++ b/cmd/grpc/dolly/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Dolly{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/gpt2/main.go b/cmd/grpc/gpt2/main.go new file mode 100644 index 00000000..d9fe2752 --- /dev/null +++ b/cmd/grpc/gpt2/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.GPT2{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/gptj/main.go b/cmd/grpc/gptj/main.go new file mode 100644 index 00000000..27d82104 --- /dev/null +++ b/cmd/grpc/gptj/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.GPTJ{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/gptneox/main.go b/cmd/grpc/gptneox/main.go new file mode 100644 index 00000000..3d005ca8 --- /dev/null +++ b/cmd/grpc/gptneox/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.GPTNeoX{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/mpt/main.go b/cmd/grpc/mpt/main.go new file mode 100644 index 00000000..58456a7d --- /dev/null +++ b/cmd/grpc/mpt/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/replit/main.go b/cmd/grpc/replit/main.go new file mode 100644 index 00000000..aed67fbc --- /dev/null +++ b/cmd/grpc/replit/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/starcoder/main.go b/cmd/grpc/starcoder/main.go new file mode 100644 index 00000000..2847acf7 --- /dev/null +++ b/cmd/grpc/starcoder/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil { + panic(err) + } +} diff --git a/pkg/grpc/llm/ggml/starcoder.go b/pkg/grpc/llm/ggml/starcoder.go deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/grpc/llm/transformers/dolly.go b/pkg/grpc/llm/transformers/dolly.go new file mode 100644 index 00000000..28a44a7a --- /dev/null +++ b/pkg/grpc/llm/transformers/dolly.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Dolly struct { + dolly *transformers.Dolly +} + +func (llm *Dolly) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewDolly(opts.Model) + llm.dolly = model + return err +} + +func (llm *Dolly) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { + return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/gpt2.go b/pkg/grpc/llm/transformers/gpt2.go new file mode 100644 index 00000000..0eaf7876 --- /dev/null +++ b/pkg/grpc/llm/transformers/gpt2.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type GPT2 struct { + gpt2 *transformers.GPT2 +} + +func (llm *GPT2) Load(opts *pb.ModelOptions) error { + model, err := transformers.New(opts.Model) + llm.gpt2 = model + return err +} + +func (llm *GPT2) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/gptj.go b/pkg/grpc/llm/transformers/gptj.go new file mode 100644 index 00000000..a7138ef4 --- /dev/null +++ b/pkg/grpc/llm/transformers/gptj.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type GPTJ struct { + gptj *transformers.GPTJ +} + +func (llm *GPTJ) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewGPTJ(opts.Model) + llm.gptj = model + return err +} + +func (llm *GPTJ) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/gptneox.go b/pkg/grpc/llm/transformers/gptneox.go new file mode 100644 index 00000000..2edf4ba8 --- /dev/null +++ b/pkg/grpc/llm/transformers/gptneox.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type GPTNeoX struct { + gptneox *transformers.GPTNeoX +} + +func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewGPTNeoX(opts.Model) + llm.gptneox = model + return err +} + +func (llm *GPTNeoX) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/mpt.go b/pkg/grpc/llm/transformers/mpt.go new file mode 100644 index 00000000..ab88418f --- /dev/null +++ b/pkg/grpc/llm/transformers/mpt.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type MPT struct { + mpt *transformers.MPT +} + +func (llm *MPT) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewMPT(opts.Model) + llm.mpt = model + return err +} + +func (llm *MPT) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { + return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/predict.go b/pkg/grpc/llm/transformers/predict.go new file mode 100644 index 00000000..861d1196 --- /dev/null +++ b/pkg/grpc/llm/transformers/predict.go @@ -0,0 +1,26 @@ +package transformers + +import ( + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption { + predictOptions := []transformers.PredictOption{ + transformers.SetTemperature(float64(opts.Temperature)), + transformers.SetTopP(float64(opts.TopP)), + transformers.SetTopK(int(opts.TopK)), + transformers.SetTokens(int(opts.Tokens)), + transformers.SetThreads(int(opts.Threads)), + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, transformers.SetBatch(int(opts.Batch))) + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, transformers.SetSeed(int(opts.Seed))) + } + + return predictOptions +} diff --git a/pkg/grpc/llm/transformers/replit.go b/pkg/grpc/llm/transformers/replit.go new file mode 100644 index 00000000..ca1d66f6 --- /dev/null +++ b/pkg/grpc/llm/transformers/replit.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Replit struct { + replit *transformers.Replit +} + +func (llm *Replit) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewReplit(opts.Model) + llm.replit = model + return err +} + +func (llm *Replit) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { + return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/starcoder.go b/pkg/grpc/llm/transformers/starcoder.go new file mode 100644 index 00000000..6e1a94bc --- /dev/null +++ b/pkg/grpc/llm/transformers/starcoder.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Starcoder struct { + starcoder *transformers.Starcoder +} + +func (llm *Starcoder) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewStarcoder(opts.Model) + llm.starcoder = model + return err +} + +func (llm *Starcoder) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { + return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 3a0c5eaa..44a06384 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -16,7 +16,6 @@ import ( "github.com/go-skynet/LocalAI/pkg/tts" bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" - transformers "github.com/go-skynet/go-ggml-transformers.cpp" "github.com/hashicorp/go-multierror" "github.com/hpcloud/tail" "github.com/phayes/freeport" @@ -55,7 +54,6 @@ var autoLoadBackends []string = []string{ LlamaBackend, Gpt4All, RwkvBackend, - //GGLLMFalconBackend, WhisperBackend, BertEmbeddingsBackend, GPTNeoXBackend, @@ -69,40 +67,6 @@ var autoLoadBackends []string = []string{ BloomzBackend, } -var starCoder = func(modelFile string) (interface{}, error) { - return transformers.NewStarcoder(modelFile) -} - -var mpt = func(modelFile string) (interface{}, error) { - return transformers.NewMPT(modelFile) -} - -var dolly = func(modelFile string) (interface{}, error) { - return transformers.NewDolly(modelFile) -} - -// func ggllmFalcon(opts ...ggllm.ModelOption) func(string) (interface{}, error) { -// return func(s string) (interface{}, error) { -// return ggllm.New(s, opts...) -// } -// } - -var gptNeoX = func(modelFile string) (interface{}, error) { - return transformers.NewGPTNeoX(modelFile) -} - -var replit = func(modelFile string) (interface{}, error) { - return transformers.NewReplit(modelFile) -} - -var gptJ = func(modelFile string) (interface{}, error) { - return transformers.NewGPTJ(modelFile) -} - -var falcon = func(modelFile string) (interface{}, error) { - return transformers.NewFalcon(modelFile) -} - var bertEmbeddings = func(modelFile string) (interface{}, error) { return bert.New(modelFile) } @@ -111,10 +75,6 @@ var bloomzLM = func(modelFile string) (interface{}, error) { return bloomz.New(modelFile) } -var transformersLM = func(modelFile string) (interface{}, error) { - return transformers.New(modelFile) -} - var stableDiffusion = func(assetDir string) (interface{}, error) { return stablediffusion.New(assetDir) } @@ -261,34 +221,32 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err err log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) switch strings.ToLower(o.backendString) { case LlamaBackend: - // return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o)) case BloomzBackend: return ml.LoadModel(o.modelFile, bloomzLM) case GPTJBackend: - return ml.LoadModel(o.modelFile, gptJ) + return ml.LoadModel(o.modelFile, ml.grpcModel(GPTJBackend, o)) case DollyBackend: - return ml.LoadModel(o.modelFile, dolly) + return ml.LoadModel(o.modelFile, ml.grpcModel(DollyBackend, o)) case MPTBackend: - return ml.LoadModel(o.modelFile, mpt) + return ml.LoadModel(o.modelFile, ml.grpcModel(MPTBackend, o)) case Gpt2Backend: - return ml.LoadModel(o.modelFile, transformersLM) + return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt2Backend, o)) case FalconBackend: return ml.LoadModel(o.modelFile, ml.grpcModel(FalconBackend, o)) case GPTNeoXBackend: - return ml.LoadModel(o.modelFile, gptNeoX) + return ml.LoadModel(o.modelFile, ml.grpcModel(GPTNeoXBackend, o)) case ReplitBackend: - return ml.LoadModel(o.modelFile, replit) + return ml.LoadModel(o.modelFile, ml.grpcModel(ReplitBackend, o)) case StableDiffusionBackend: return ml.LoadModel(o.modelFile, stableDiffusion) case PiperBackend: return ml.LoadModel(o.modelFile, piperTTS(filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data"))) case StarcoderBackend: - return ml.LoadModel(o.modelFile, starCoder) + return ml.LoadModel(o.modelFile, ml.grpcModel(StarcoderBackend, o)) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o)) - // return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) case BertEmbeddingsBackend: return ml.LoadModel(o.modelFile, bertEmbeddings)