mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat: move llama to a grpc
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
b816009db0
commit
58f6aab637
9
Makefile
9
Makefile
@ -67,8 +67,8 @@ WHITE := $(shell tput -Txterm setaf 7)
|
||||
CYAN := $(shell tput -Txterm setaf 6)
|
||||
RESET := $(shell tput -Txterm sgr0)
|
||||
|
||||
C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
|
||||
LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
|
||||
C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
|
||||
LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
|
||||
|
||||
ifeq ($(BUILD_TYPE),openblas)
|
||||
CGO_LDFLAGS+=-lopenblas
|
||||
@ -369,5 +369,8 @@ falcon-grpc: backend-assets/grpc
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \
|
||||
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/
|
||||
|
||||
llama-grpc: backend-assets/grpc
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \
|
||||
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/
|
||||
|
||||
grpcs: falcon-grpc
|
||||
grpcs: falcon-grpc llama-grpc
|
@ -18,7 +18,6 @@ import (
|
||||
"github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
|
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
)
|
||||
@ -36,6 +35,11 @@ func gRPCModelOpts(c Config) *pb.ModelOptions {
|
||||
ContextSize: int32(c.ContextSize),
|
||||
Seed: int32(c.Seed),
|
||||
NBatch: int32(b),
|
||||
F16Memory: c.F16,
|
||||
MLock: c.MMlock,
|
||||
NUMA: c.NUMA,
|
||||
Embeddings: c.Embeddings,
|
||||
LowVRAM: c.LowVRAM,
|
||||
NGPULayers: int32(c.NGPULayers),
|
||||
MMap: c.MMap,
|
||||
MainGPU: c.MainGPU,
|
||||
@ -43,32 +47,6 @@ func gRPCModelOpts(c Config) *pb.ModelOptions {
|
||||
}
|
||||
}
|
||||
|
||||
// func defaultGGLLMOpts(c Config) []ggllm.ModelOption {
|
||||
// ggllmOpts := []ggllm.ModelOption{}
|
||||
// if c.ContextSize != 0 {
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetContext(c.ContextSize))
|
||||
// }
|
||||
// // F16 doesn't seem to produce good output at all!
|
||||
// //if c.F16 {
|
||||
// // llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
// //}
|
||||
|
||||
// if c.NGPULayers != 0 {
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(c.NGPULayers))
|
||||
// }
|
||||
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetMMap(c.MMap))
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(c.MainGPU))
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(c.TensorSplit))
|
||||
// if c.Batch != 0 {
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(c.Batch))
|
||||
// } else {
|
||||
// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512))
|
||||
// }
|
||||
|
||||
// return ggllmOpts
|
||||
// }
|
||||
|
||||
func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions {
|
||||
promptCachePath := ""
|
||||
if c.PromptCachePath != "" {
|
||||
@ -77,14 +55,18 @@ func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions {
|
||||
promptCachePath = p
|
||||
}
|
||||
return &pb.PredictOptions{
|
||||
Temperature: float32(c.Temperature),
|
||||
TopP: float32(c.TopP),
|
||||
TopK: int32(c.TopK),
|
||||
Tokens: int32(c.Maxtokens),
|
||||
Threads: int32(c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
Temperature: float32(c.Temperature),
|
||||
TopP: float32(c.TopP),
|
||||
TopK: int32(c.TopK),
|
||||
Tokens: int32(c.Maxtokens),
|
||||
Threads: int32(c.Threads),
|
||||
PromptCacheAll: c.PromptCacheAll,
|
||||
PromptCacheRO: c.PromptCacheRO,
|
||||
PromptCachePath: promptCachePath,
|
||||
F16KV: c.F16,
|
||||
DebugMode: c.Debug,
|
||||
Grammar: c.Grammar,
|
||||
|
||||
Mirostat: int32(c.Mirostat),
|
||||
MirostatETA: float32(c.MirostatETA),
|
||||
MirostatTAU: float32(c.MirostatTAU),
|
||||
@ -105,200 +87,6 @@ func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions {
|
||||
}
|
||||
}
|
||||
|
||||
// func buildGGLLMPredictOptions(c Config, modelPath string) []ggllm.PredictOption {
|
||||
// // Generate the prediction using the language model
|
||||
// predictOptions := []ggllm.PredictOption{
|
||||
// ggllm.SetTemperature(c.Temperature),
|
||||
// ggllm.SetTopP(c.TopP),
|
||||
// ggllm.SetTopK(c.TopK),
|
||||
// ggllm.SetTokens(c.Maxtokens),
|
||||
// ggllm.SetThreads(c.Threads),
|
||||
// }
|
||||
|
||||
// if c.PromptCacheAll {
|
||||
// predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll)
|
||||
// }
|
||||
|
||||
// if c.PromptCacheRO {
|
||||
// predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO)
|
||||
// }
|
||||
|
||||
// if c.PromptCachePath != "" {
|
||||
// // Create parent directory
|
||||
// p := filepath.Join(modelPath, c.PromptCachePath)
|
||||
// os.MkdirAll(filepath.Dir(p), 0755)
|
||||
// predictOptions = append(predictOptions, ggllm.SetPathPromptCache(p))
|
||||
// }
|
||||
|
||||
// if c.Mirostat != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetMirostat(c.Mirostat))
|
||||
// }
|
||||
|
||||
// if c.MirostatETA != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetMirostatETA(c.MirostatETA))
|
||||
// }
|
||||
|
||||
// if c.MirostatTAU != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetMirostatTAU(c.MirostatTAU))
|
||||
// }
|
||||
|
||||
// if c.Debug {
|
||||
// predictOptions = append(predictOptions, ggllm.Debug)
|
||||
// }
|
||||
|
||||
// predictOptions = append(predictOptions, ggllm.SetStopWords(c.StopWords...))
|
||||
|
||||
// if c.RepeatPenalty != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetPenalty(c.RepeatPenalty))
|
||||
// }
|
||||
|
||||
// if c.Keep != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetNKeep(c.Keep))
|
||||
// }
|
||||
|
||||
// if c.Batch != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetBatch(c.Batch))
|
||||
// }
|
||||
|
||||
// if c.IgnoreEOS {
|
||||
// predictOptions = append(predictOptions, ggllm.IgnoreEOS)
|
||||
// }
|
||||
|
||||
// if c.Seed != 0 {
|
||||
// predictOptions = append(predictOptions, ggllm.SetSeed(c.Seed))
|
||||
// }
|
||||
|
||||
// //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
// predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(c.FrequencyPenalty))
|
||||
// predictOptions = append(predictOptions, ggllm.SetMlock(c.MMlock))
|
||||
// predictOptions = append(predictOptions, ggllm.SetMemoryMap(c.MMap))
|
||||
// predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(c.MainGPU))
|
||||
// predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(c.TensorSplit))
|
||||
// predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(c.TFZ))
|
||||
// predictOptions = append(predictOptions, ggllm.SetTypicalP(c.TypicalP))
|
||||
|
||||
// return predictOptions
|
||||
// }
|
||||
|
||||
func defaultLLamaOpts(c Config) []llama.ModelOption {
|
||||
llamaOpts := []llama.ModelOption{}
|
||||
if c.ContextSize != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
|
||||
}
|
||||
if c.F16 {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
if c.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
|
||||
if c.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit))
|
||||
if c.Batch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch))
|
||||
} else {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512))
|
||||
}
|
||||
|
||||
if c.NUMA {
|
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA)
|
||||
}
|
||||
|
||||
if c.LowVRAM {
|
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
|
||||
}
|
||||
|
||||
return llamaOpts
|
||||
}
|
||||
|
||||
func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption {
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(c.Temperature),
|
||||
llama.SetTopP(c.TopP),
|
||||
llama.SetTopK(c.TopK),
|
||||
llama.SetTokens(c.Maxtokens),
|
||||
llama.SetThreads(c.Threads),
|
||||
}
|
||||
|
||||
if c.PromptCacheAll {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
||||
}
|
||||
|
||||
if c.PromptCacheRO {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar))
|
||||
|
||||
if c.PromptCachePath != "" {
|
||||
// Create parent directory
|
||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
||||
os.MkdirAll(filepath.Dir(p), 0755)
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
|
||||
}
|
||||
|
||||
if c.Mirostat != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat))
|
||||
}
|
||||
|
||||
if c.MirostatETA != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA))
|
||||
}
|
||||
|
||||
if c.MirostatTAU != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU))
|
||||
}
|
||||
|
||||
if c.Debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...))
|
||||
|
||||
if c.RepeatPenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty))
|
||||
}
|
||||
|
||||
if c.Keep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep))
|
||||
}
|
||||
|
||||
if c.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
if c.F16 {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if c.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if c.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
|
||||
}
|
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(c.MMlock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ))
|
||||
predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP))
|
||||
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) {
|
||||
if c.Backend != model.StableDiffusionBackend {
|
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||
@ -351,14 +139,12 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config,
|
||||
|
||||
modelFile := c.Model
|
||||
|
||||
llamaOpts := defaultLLamaOpts(c)
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel interface{}
|
||||
var err error
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithLlamaOpts(llamaOpts...),
|
||||
model.WithLoadGRPCOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.assetsDestination),
|
||||
@ -377,14 +163,34 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config,
|
||||
|
||||
var fn func() ([]float32, error)
|
||||
switch model := inferenceModel.(type) {
|
||||
case *llama.LLama:
|
||||
case *grpc.Client:
|
||||
fn = func() ([]float32, error) {
|
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
|
||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
|
||||
if len(tokens) > 0 {
|
||||
return model.TokenEmbeddings(tokens, predictOptions...)
|
||||
embeds := []int32{}
|
||||
|
||||
for _, t := range tokens {
|
||||
embeds = append(embeds, int32(t))
|
||||
}
|
||||
predictOptions.EmbeddingTokens = embeds
|
||||
|
||||
res, err := model.Embeddings(context.TODO(), predictOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Embeddings, nil
|
||||
}
|
||||
return model.Embeddings(s, predictOptions...)
|
||||
predictOptions.Embeddings = s
|
||||
|
||||
res, err := model.Embeddings(context.TODO(), predictOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Embeddings, nil
|
||||
}
|
||||
|
||||
// bert embeddings
|
||||
case *bert.Bert:
|
||||
fn = func() ([]float32, error) {
|
||||
@ -432,14 +238,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to
|
||||
supportStreams := false
|
||||
modelFile := c.Model
|
||||
|
||||
llamaOpts := defaultLLamaOpts(c)
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel interface{}
|
||||
var err error
|
||||
|
||||
opts := []model.Option{
|
||||
model.WithLlamaOpts(llamaOpts...),
|
||||
model.WithLoadGRPCOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.assetsDestination),
|
||||
@ -708,26 +512,6 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to
|
||||
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch))
|
||||
}
|
||||
|
||||
str, er := model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
)
|
||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil)
|
||||
return str, er
|
||||
}
|
||||
case *llama.LLama:
|
||||
supportStreams = true
|
||||
fn = func() (string, error) {
|
||||
|
||||
if tokenCallback != nil {
|
||||
model.SetTokenCallback(tokenCallback)
|
||||
}
|
||||
|
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
|
||||
|
||||
str, er := model.Predict(
|
||||
s,
|
||||
predictOptions...,
|
||||
|
25
cmd/grpc/llama/main.go
Normal file
25
cmd/grpc/llama/main.go
Normal file
@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
@ -47,6 +47,17 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewLLMClient(conn)
|
||||
|
||||
return client.Embedding(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
|
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
|
@ -8,4 +8,5 @@ type LLM interface {
|
||||
Predict(*pb.PredictOptions) (string, error)
|
||||
PredictStream(*pb.PredictOptions, chan string)
|
||||
Load(*pb.ModelOptions) error
|
||||
Embeddings(*pb.PredictOptions) ([]float32, error)
|
||||
}
|
||||
|
@ -42,6 +42,10 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption {
|
||||
predictOptions := []ggllm.PredictOption{
|
||||
ggllm.SetTemperature(float64(opts.Temperature)),
|
||||
|
165
pkg/grpc/llm/llama/llama.go
Normal file
165
pkg/grpc/llm/llama/llama.go
Normal file
@ -0,0 +1,165 @@
|
||||
package llama
|
||||
|
||||
// This is a wrapper to statisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
type LLM struct {
|
||||
llama *llama.LLama
|
||||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
llamaOpts := []llama.ModelOption{}
|
||||
|
||||
if opts.ContextSize != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
|
||||
}
|
||||
if opts.F16Memory {
|
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
}
|
||||
if opts.Embeddings {
|
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
|
||||
}
|
||||
if opts.NGPULayers != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
|
||||
}
|
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
|
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
|
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
|
||||
if opts.NBatch != 0 {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
|
||||
} else {
|
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512))
|
||||
}
|
||||
|
||||
if opts.NUMA {
|
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA)
|
||||
}
|
||||
|
||||
if opts.LowVRAM {
|
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
|
||||
}
|
||||
|
||||
model, err := llama.New(opts.Model, llamaOpts...)
|
||||
llm.llama = model
|
||||
return err
|
||||
}
|
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
|
||||
predictOptions := []llama.PredictOption{
|
||||
llama.SetTemperature(float64(opts.Temperature)),
|
||||
llama.SetTopP(float64(opts.TopP)),
|
||||
llama.SetTopK(int(opts.TopK)),
|
||||
llama.SetTokens(int(opts.Tokens)),
|
||||
llama.SetThreads(int(opts.Threads)),
|
||||
}
|
||||
|
||||
if opts.PromptCacheAll {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
|
||||
}
|
||||
|
||||
if opts.PromptCacheRO {
|
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar))
|
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" {
|
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
|
||||
}
|
||||
|
||||
if opts.Mirostat != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
|
||||
}
|
||||
|
||||
if opts.MirostatETA != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA)))
|
||||
}
|
||||
|
||||
if opts.MirostatTAU != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU)))
|
||||
}
|
||||
|
||||
if opts.Debug {
|
||||
predictOptions = append(predictOptions, llama.Debug)
|
||||
}
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))
|
||||
|
||||
if opts.PresencePenalty != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty)))
|
||||
}
|
||||
|
||||
if opts.NKeep != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
|
||||
}
|
||||
|
||||
if opts.Batch != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
|
||||
}
|
||||
|
||||
if opts.F16KV {
|
||||
predictOptions = append(predictOptions, llama.EnableF16KV)
|
||||
}
|
||||
|
||||
if opts.IgnoreEOS {
|
||||
predictOptions = append(predictOptions, llama.IgnoreEOS)
|
||||
}
|
||||
|
||||
if opts.Seed != 0 {
|
||||
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
|
||||
}
|
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty)))
|
||||
predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
|
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
|
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
|
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ)))
|
||||
predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP)))
|
||||
return predictOptions
|
||||
}
|
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
|
||||
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
|
||||
}
|
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
|
||||
results <- token
|
||||
return true
|
||||
}))
|
||||
|
||||
go func() {
|
||||
_, err := llm.llama.Predict(opts.Prompt, predictOptions...)
|
||||
if err != nil {
|
||||
fmt.Println("err: ", err)
|
||||
}
|
||||
close(results)
|
||||
}()
|
||||
}
|
||||
|
||||
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
|
||||
predictOptions := buildPredictOptions(opts)
|
||||
|
||||
if len(opts.EmbeddingTokens) > 0 {
|
||||
tokens := []int{}
|
||||
for _, t := range opts.EmbeddingTokens {
|
||||
tokens = append(tokens, int(t))
|
||||
}
|
||||
return llm.llama.TokenEmbeddings(tokens, predictOptions...)
|
||||
}
|
||||
|
||||
return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
|
||||
}
|
@ -87,7 +87,6 @@ type PredictOptions struct {
|
||||
MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"`
|
||||
PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"`
|
||||
LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"`
|
||||
PathPromptCache string `protobuf:"bytes,24,opt,name=PathPromptCache,proto3" json:"PathPromptCache,omitempty"`
|
||||
MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"`
|
||||
MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"`
|
||||
PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"`
|
||||
@ -98,6 +97,8 @@ type PredictOptions struct {
|
||||
TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"`
|
||||
PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"`
|
||||
Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"`
|
||||
EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"`
|
||||
Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"`
|
||||
}
|
||||
|
||||
func (x *PredictOptions) Reset() {
|
||||
@ -293,13 +294,6 @@ func (x *PredictOptions) GetLogitBias() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PredictOptions) GetPathPromptCache() string {
|
||||
if x != nil {
|
||||
return x.PathPromptCache
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PredictOptions) GetMLock() bool {
|
||||
if x != nil {
|
||||
return x.MLock
|
||||
@ -370,6 +364,20 @@ func (x *PredictOptions) GetDebug() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *PredictOptions) GetEmbeddingTokens() []int32 {
|
||||
if x != nil {
|
||||
return x.EmbeddingTokens
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *PredictOptions) GetEmbeddings() string {
|
||||
if x != nil {
|
||||
return x.Embeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// The response message containing the result
|
||||
type Reply struct {
|
||||
state protoimpl.MessageState
|
||||
@ -624,13 +632,60 @@ func (x *Result) GetSuccess() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type EmbeddingResult struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"`
|
||||
}
|
||||
|
||||
func (x *EmbeddingResult) Reset() {
|
||||
*x = EmbeddingResult{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *EmbeddingResult) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*EmbeddingResult) ProtoMessage() {}
|
||||
|
||||
func (x *EmbeddingResult) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead.
|
||||
func (*EmbeddingResult) Descriptor() ([]byte, []int) {
|
||||
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{5}
|
||||
}
|
||||
|
||||
func (x *EmbeddingResult) GetEmbeddings() []float32 {
|
||||
if x != nil {
|
||||
return x.Embeddings
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_pkg_grpc_proto_llmserver_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{
|
||||
0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x2f, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x12, 0x03, 0x6c, 0x6c, 0x6d, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d,
|
||||
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x80, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69,
|
||||
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69,
|
||||
0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x72, 0x6f,
|
||||
0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70,
|
||||
0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52,
|
||||
@ -673,28 +728,30 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{
|
||||
0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, 0x16, 0x20,
|
||||
0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x12,
|
||||
0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x28, 0x0a,
|
||||
0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65,
|
||||
0x18, 0x18, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d,
|
||||
0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b,
|
||||
0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a,
|
||||
0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61,
|
||||
0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65,
|
||||
0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70,
|
||||
0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f,
|
||||
0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08,
|
||||
0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12,
|
||||
0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69,
|
||||
0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e,
|
||||
0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c,
|
||||
0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72,
|
||||
0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20,
|
||||
0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f,
|
||||
0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x18, 0x21, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50,
|
||||
0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x18, 0x22, 0x20, 0x01,
|
||||
0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70,
|
||||
0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x14, 0x0a,
|
||||
0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c,
|
||||
0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28,
|
||||
0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70,
|
||||
0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52,
|
||||
0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12,
|
||||
0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f,
|
||||
0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61,
|
||||
0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72,
|
||||
0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12,
|
||||
0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e,
|
||||
0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b,
|
||||
0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54,
|
||||
0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12,
|
||||
0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61,
|
||||
0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74,
|
||||
0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62,
|
||||
0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x12,
|
||||
0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65,
|
||||
0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64,
|
||||
0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62,
|
||||
0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x45,
|
||||
0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70,
|
||||
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x82, 0x03, 0x0a,
|
||||
0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a,
|
||||
@ -724,26 +781,33 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{
|
||||
0x74, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d,
|
||||
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65,
|
||||
0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32,
|
||||
0xc4, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74,
|
||||
0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65,
|
||||
0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c,
|
||||
0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13,
|
||||
0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22,
|
||||
0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11,
|
||||
0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00,
|
||||
0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61,
|
||||
0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22,
|
||||
0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75,
|
||||
0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73,
|
||||
0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e,
|
||||
0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65,
|
||||
0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74,
|
||||
0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52,
|
||||
0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63,
|
||||
0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f,
|
||||
0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70,
|
||||
0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79,
|
||||
0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
|
||||
0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67,
|
||||
0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49,
|
||||
0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65,
|
||||
0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74,
|
||||
0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c,
|
||||
0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74,
|
||||
0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69,
|
||||
0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e,
|
||||
0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62,
|
||||
0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65,
|
||||
0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c,
|
||||
0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c,
|
||||
0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74,
|
||||
0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76,
|
||||
0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a,
|
||||
0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73,
|
||||
0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b,
|
||||
0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@ -758,25 +822,28 @@ func file_pkg_grpc_proto_llmserver_proto_rawDescGZIP() []byte {
|
||||
return file_pkg_grpc_proto_llmserver_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_pkg_grpc_proto_llmserver_proto_goTypes = []interface{}{
|
||||
(*HealthMessage)(nil), // 0: llm.HealthMessage
|
||||
(*PredictOptions)(nil), // 1: llm.PredictOptions
|
||||
(*Reply)(nil), // 2: llm.Reply
|
||||
(*ModelOptions)(nil), // 3: llm.ModelOptions
|
||||
(*Result)(nil), // 4: llm.Result
|
||||
(*HealthMessage)(nil), // 0: llm.HealthMessage
|
||||
(*PredictOptions)(nil), // 1: llm.PredictOptions
|
||||
(*Reply)(nil), // 2: llm.Reply
|
||||
(*ModelOptions)(nil), // 3: llm.ModelOptions
|
||||
(*Result)(nil), // 4: llm.Result
|
||||
(*EmbeddingResult)(nil), // 5: llm.EmbeddingResult
|
||||
}
|
||||
var file_pkg_grpc_proto_llmserver_proto_depIdxs = []int32{
|
||||
0, // 0: llm.LLM.Health:input_type -> llm.HealthMessage
|
||||
1, // 1: llm.LLM.Predict:input_type -> llm.PredictOptions
|
||||
3, // 2: llm.LLM.LoadModel:input_type -> llm.ModelOptions
|
||||
1, // 3: llm.LLM.PredictStream:input_type -> llm.PredictOptions
|
||||
2, // 4: llm.LLM.Health:output_type -> llm.Reply
|
||||
2, // 5: llm.LLM.Predict:output_type -> llm.Reply
|
||||
4, // 6: llm.LLM.LoadModel:output_type -> llm.Result
|
||||
2, // 7: llm.LLM.PredictStream:output_type -> llm.Reply
|
||||
4, // [4:8] is the sub-list for method output_type
|
||||
0, // [0:4] is the sub-list for method input_type
|
||||
1, // 4: llm.LLM.Embedding:input_type -> llm.PredictOptions
|
||||
2, // 5: llm.LLM.Health:output_type -> llm.Reply
|
||||
2, // 6: llm.LLM.Predict:output_type -> llm.Reply
|
||||
4, // 7: llm.LLM.LoadModel:output_type -> llm.Result
|
||||
2, // 8: llm.LLM.PredictStream:output_type -> llm.Reply
|
||||
5, // 9: llm.LLM.Embedding:output_type -> llm.EmbeddingResult
|
||||
5, // [5:10] is the sub-list for method output_type
|
||||
0, // [0:5] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
@ -848,6 +915,18 @@ func file_pkg_grpc_proto_llmserver_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_pkg_grpc_proto_llmserver_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*EmbeddingResult); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
@ -855,7 +934,7 @@ func file_pkg_grpc_proto_llmserver_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_pkg_grpc_proto_llmserver_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 5,
|
||||
NumMessages: 6,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
@ -12,6 +12,7 @@ service LLM {
|
||||
rpc Predict(PredictOptions) returns (Reply) {}
|
||||
rpc LoadModel(ModelOptions) returns (Result) {}
|
||||
rpc PredictStream(PredictOptions) returns (stream Reply) {}
|
||||
rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
|
||||
}
|
||||
|
||||
message HealthMessage {}
|
||||
@ -41,7 +42,6 @@ message PredictOptions {
|
||||
float MirostatTAU = 21;
|
||||
bool PenalizeNL = 22;
|
||||
string LogitBias = 23;
|
||||
string PathPromptCache = 24;
|
||||
bool MLock = 25;
|
||||
bool MMap = 26;
|
||||
bool PromptCacheAll = 27;
|
||||
@ -52,6 +52,8 @@ message PredictOptions {
|
||||
float TopP = 32;
|
||||
string PromptCachePath = 33;
|
||||
bool Debug = 34;
|
||||
repeated int32 EmbeddingTokens = 35;
|
||||
string Embeddings = 36;
|
||||
}
|
||||
|
||||
// The response message containing the result
|
||||
@ -79,4 +81,8 @@ message ModelOptions {
|
||||
message Result {
|
||||
string message = 1;
|
||||
bool success = 2;
|
||||
}
|
||||
|
||||
message EmbeddingResult {
|
||||
repeated float embeddings = 1;
|
||||
}
|
@ -26,6 +26,7 @@ type LLMClient interface {
|
||||
Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error)
|
||||
LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error)
|
||||
PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error)
|
||||
Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error)
|
||||
}
|
||||
|
||||
type lLMClient struct {
|
||||
@ -95,6 +96,15 @@ func (x *lLMPredictStreamClient) Recv() (*Reply, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *lLMClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
|
||||
out := new(EmbeddingResult)
|
||||
err := c.cc.Invoke(ctx, "/llm.LLM/Embedding", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LLMServer is the server API for LLM service.
|
||||
// All implementations must embed UnimplementedLLMServer
|
||||
// for forward compatibility
|
||||
@ -103,6 +113,7 @@ type LLMServer interface {
|
||||
Predict(context.Context, *PredictOptions) (*Reply, error)
|
||||
LoadModel(context.Context, *ModelOptions) (*Result, error)
|
||||
PredictStream(*PredictOptions, LLM_PredictStreamServer) error
|
||||
Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error)
|
||||
mustEmbedUnimplementedLLMServer()
|
||||
}
|
||||
|
||||
@ -122,6 +133,9 @@ func (UnimplementedLLMServer) LoadModel(context.Context, *ModelOptions) (*Result
|
||||
func (UnimplementedLLMServer) PredictStream(*PredictOptions, LLM_PredictStreamServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method PredictStream not implemented")
|
||||
}
|
||||
func (UnimplementedLLMServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented")
|
||||
}
|
||||
func (UnimplementedLLMServer) mustEmbedUnimplementedLLMServer() {}
|
||||
|
||||
// UnsafeLLMServer may be embedded to opt out of forward compatibility for this service.
|
||||
@ -210,6 +224,24 @@ func (x *lLMPredictStreamServer) Send(m *Reply) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _LLM_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(PredictOptions)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LLMServer).Embedding(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/llm.LLM/Embedding",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LLMServer).Embedding(ctx, req.(*PredictOptions))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// LLM_ServiceDesc is the grpc.ServiceDesc for LLM service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@ -229,6 +261,10 @@ var LLM_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "LoadModel",
|
||||
Handler: _LLM_LoadModel_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Embedding",
|
||||
Handler: _LLM_Embedding_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
@ -29,6 +29,15 @@ func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, e
|
||||
return &pb.Reply{Message: "OK"}, nil
|
||||
}
|
||||
|
||||
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
|
||||
embeds, err := s.llm.Embeddings(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pb.EmbeddingResult{Embeddings: embeds}, nil
|
||||
}
|
||||
|
||||
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
|
||||
err := s.llm.Load(in)
|
||||
if err != nil {
|
||||
|
@ -17,7 +17,6 @@ import (
|
||||
bloomz "github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hpcloud/tail"
|
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
@ -135,11 +134,11 @@ var lcHuggingFace = func(repoId string) (interface{}, error) {
|
||||
return langchain.NewHuggingFace(repoId)
|
||||
}
|
||||
|
||||
func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
|
||||
return func(s string) (interface{}, error) {
|
||||
return llama.New(s, opts...)
|
||||
}
|
||||
}
|
||||
// func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
|
||||
// return func(s string) (interface{}, error) {
|
||||
// return llama.New(s, opts...)
|
||||
// }
|
||||
// }
|
||||
|
||||
func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) {
|
||||
return func(s string) (interface{}, error) {
|
||||
@ -263,7 +262,8 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err err
|
||||
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
|
||||
switch strings.ToLower(o.backendString) {
|
||||
case LlamaBackend:
|
||||
return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...))
|
||||
// return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...))
|
||||
return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o))
|
||||
case BloomzBackend:
|
||||
return ml.LoadModel(o.modelFile, bloomzLM)
|
||||
case GPTJBackend:
|
||||
@ -325,7 +325,6 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) {
|
||||
model, modelerr := ml.BackendLoader(
|
||||
WithBackendString(b),
|
||||
WithModelFile(o.modelFile),
|
||||
WithLlamaOpts(o.llamaOpts...),
|
||||
WithLoadGRPCOpts(o.gRPCOptions),
|
||||
WithThreads(o.threads),
|
||||
WithAssetDir(o.assetDir),
|
||||
|
@ -2,13 +2,11 @@ package model
|
||||
|
||||
import (
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
llama "github.com/go-skynet/go-llama.cpp"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
backendString string
|
||||
modelFile string
|
||||
llamaOpts []llama.ModelOption
|
||||
threads uint32
|
||||
assetDir string
|
||||
|
||||
@ -35,12 +33,6 @@ func WithLoadGRPCOpts(opts *pb.ModelOptions) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithLlamaOpts(opts ...llama.ModelOption) Option {
|
||||
return func(o *Options) {
|
||||
o.llamaOpts = append(o.llamaOpts, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads uint32) Option {
|
||||
return func(o *Options) {
|
||||
o.threads = threads
|
||||
|
Loading…
Reference in New Issue
Block a user