diff --git a/api/api_test.go b/api/api_test.go index 2ffcb71e..a602229d 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -389,70 +389,6 @@ var _ = Describe("API test", func() { }) }) - Context("External gRPCs", func() { - BeforeEach(func() { - modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - c, cancel = context.WithCancel(context.Background()) - - app, err := App( - append(commonOpts, - options.WithContext(c), - options.WithAudioDir(tmpdir), - options.WithImageDir(tmpdir), - options.WithModelLoader(modelLoader), - options.WithBackendAssets(backendAssets), - options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), - options.WithBackendAssetsOutput(tmpdir))..., - ) - Expect(err).ToNot(HaveOccurred()) - go app.Listen("127.0.0.1:9090") - - defaultConfig := openai.DefaultConfig("") - defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" - - // Wait for API to be ready - client = openai.NewClientWithConfig(defaultConfig) - Eventually(func() error { - _, err := client.ListModels(context.TODO()) - return err - }, "2m").ShouldNot(HaveOccurred()) - }) - - AfterEach(func() { - cancel() - app.Shutdown() - os.RemoveAll(tmpdir) - }) - - It("calculate embeddings with huggingface", func() { - if runtime.GOOS != "linux" { - Skip("test supported only on linux") - } - resp, err := client.CreateEmbeddings( - context.Background(), - openai.EmbeddingRequest{ - Model: openai.AdaCodeSearchCode, - Input: []string{"sun", "cat"}, - }, - ) - Expect(err).ToNot(HaveOccurred()) - Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) - Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) - - sunEmbedding := resp.Data[0].Embedding - resp2, err := client.CreateEmbeddings( - context.Background(), - openai.EmbeddingRequest{ - Model: openai.AdaCodeSearchCode, - Input: []string{"sun"}, - }, - ) - Expect(err).ToNot(HaveOccurred()) - Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) - Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) - }) - }) - Context("Model gallery", func() { BeforeEach(func() { var err error @@ -573,7 +509,10 @@ var _ = Describe("API test", func() { var err error app, err = App( append(commonOpts, - options.WithContext(c), options.WithModelLoader(modelLoader))...) + options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), + options.WithContext(c), + options.WithModelLoader(modelLoader), + )...) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -628,7 +567,7 @@ var _ = Describe("API test", func() { }) It("returns errors", func() { - backends := len(model.AutoLoadBackends) + backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends))) @@ -675,6 +614,36 @@ var _ = Describe("API test", func() { Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) }) + Context("External gRPC calls", func() { + It("calculate embeddings with huggingface", func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + resp, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: openai.AdaCodeSearchCode, + Input: []string{"sun", "cat"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) + Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) + + sunEmbedding := resp.Data[0].Embedding + resp2, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: openai.AdaCodeSearchCode, + Input: []string{"sun"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) + Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) + }) + }) + Context("backends", func() { It("runs rwkv completion", func() { if runtime.GOOS != "linux" { diff --git a/api/backend/llm.go b/api/backend/llm.go index 593eea3c..23a5ca4c 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -73,6 +73,9 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt return ss, err } else { reply, err := inferenceModel.Predict(o.Context, opts) + if err != nil { + return "", err + } return reply.Message, err } } diff --git a/pkg/grpc/tts/piper.go b/pkg/grpc/tts/piper.go index dbaa4b73..3bc85e0d 100644 --- a/pkg/grpc/tts/piper.go +++ b/pkg/grpc/tts/piper.go @@ -3,7 +3,9 @@ package tts // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( + "fmt" "os" + "path/filepath" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" @@ -16,6 +18,9 @@ type Piper struct { } func (sd *Piper) Load(opts *pb.ModelOptions) error { + if filepath.Ext(opts.Model) != ".onnx" { + return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.Model) + } var err error // Note: the Model here is a path to a directory containing the model files sd.piper, err = New(opts.LibrarySearchPath) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 32c9afc6..08bf6c4d 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -206,10 +206,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc res, err := client.LoadModel(o.context, &options) if err != nil { - return nil, err + return nil, fmt.Errorf("could not load model: %w", err) } if !res.Success { - return nil, fmt.Errorf("could not load model: %s", res.Message) + return nil, fmt.Errorf("could not load model (no success): %s", res.Message) } return client, nil @@ -289,7 +289,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { err = multierror.Append(err, modelerr) log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error()) } else if model == nil { - err = multierror.Append(err, modelerr) + err = multierror.Append(err, fmt.Errorf("backend returned no usable model")) log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model") } }