This commit is contained in:
Ettore Di Giacinto 2023-07-21 00:52:43 +02:00
parent e459f114cd
commit c71c729bc2
4 changed files with 46 additions and 69 deletions

View File

@ -389,70 +389,6 @@ var _ = Describe("API test", func() {
}) })
}) })
Context("External gRPCs", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background())
app, err := App(
append(commonOpts,
options.WithContext(c),
options.WithAudioDir(tmpdir),
options.WithImageDir(tmpdir),
options.WithModelLoader(modelLoader),
options.WithBackendAssets(backendAssets),
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithBackendAssetsOutput(tmpdir))...,
)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
})
It("calculate embeddings with huggingface", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
})
})
Context("Model gallery", func() { Context("Model gallery", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
@ -573,7 +509,10 @@ var _ = Describe("API test", func() {
var err error var err error
app, err = App( app, err = App(
append(commonOpts, append(commonOpts,
options.WithContext(c), options.WithModelLoader(modelLoader))...) options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithContext(c),
options.WithModelLoader(modelLoader),
)...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -628,7 +567,7 @@ var _ = Describe("API test", func() {
}) })
It("returns errors", func() { It("returns errors", func() {
backends := len(model.AutoLoadBackends) backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends))) Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
@ -675,6 +614,36 @@ var _ = Describe("API test", func() {
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
}) })
Context("External gRPC calls", func() {
It("calculate embeddings with huggingface", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
})
})
Context("backends", func() { Context("backends", func() {
It("runs rwkv completion", func() { It("runs rwkv completion", func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {

View File

@ -73,6 +73,9 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
return ss, err return ss, err
} else { } else {
reply, err := inferenceModel.Predict(o.Context, opts) reply, err := inferenceModel.Predict(o.Context, opts)
if err != nil {
return "", err
}
return reply.Message, err return reply.Message, err
} }
} }

View File

@ -3,7 +3,9 @@ package tts
// This is a wrapper to statisfy the GRPC service interface // This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import ( import (
"fmt"
"os" "os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/grpc/base" "github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@ -16,6 +18,9 @@ type Piper struct {
} }
func (sd *Piper) Load(opts *pb.ModelOptions) error { func (sd *Piper) Load(opts *pb.ModelOptions) error {
if filepath.Ext(opts.Model) != ".onnx" {
return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.Model)
}
var err error var err error
// Note: the Model here is a path to a directory containing the model files // Note: the Model here is a path to a directory containing the model files
sd.piper, err = New(opts.LibrarySearchPath) sd.piper, err = New(opts.LibrarySearchPath)

View File

@ -206,10 +206,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
res, err := client.LoadModel(o.context, &options) res, err := client.LoadModel(o.context, &options)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not load model: %w", err)
} }
if !res.Success { if !res.Success {
return nil, fmt.Errorf("could not load model: %s", res.Message) return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
} }
return client, nil return client, nil
@ -289,7 +289,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
err = multierror.Append(err, modelerr) err = multierror.Append(err, modelerr)
log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error()) log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error())
} else if model == nil { } else if model == nil {
err = multierror.Append(err, modelerr) err = multierror.Append(err, fmt.Errorf("backend returned no usable model"))
log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model") log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model")
} }
} }