package backend

import (
	"context"
	"fmt"
	"sync"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	model "github.com/go-skynet/LocalAI/pkg/model"
	bert "github.com/go-skynet/go-bert.cpp"
)

// ModelEmbedding loads the model described by c and returns a closure that
// computes an embedding vector when invoked.
//
// The closure embeds tokens when tokens is non-empty and the string s
// otherwise. The model is loaded eagerly, before ModelEmbedding returns: via
// loader.GreedyLoader when c.Backend is empty, and via loader.BackendLoader
// with that backend otherwise.
//
// ModelEmbedding returns an error when c.Embeddings is false or when loading
// the model fails. The returned closure reports an error when the backend call
// fails or when the loaded backend is neither a gRPC client nor a bert model.
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
	if !c.Embeddings {
		return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
	}

	modelFile := c.Model
	grpcOpts := gRPCModelOpts(c)

	opts := []model.Option{
		model.WithLoadGRPCOpts(grpcOpts),
		model.WithThreads(uint32(c.Threads)),
		model.WithAssetDir(o.AssetsDestination),
		model.WithModelFile(modelFile),
	}

	var inferenceModel interface{}
	var err error
	if c.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		opts = append(opts, model.WithBackendString(c.Backend))
		inferenceModel, err = loader.BackendLoader(opts...)
	}
	if err != nil {
		return nil, err
	}

	// The switch variable is deliberately not called "model", so the imported
	// model package stays visible inside the cases.
	var fn func() ([]float32, error)
	switch m := inferenceModel.(type) {
	case *grpc.Client:
		fn = func() ([]float32, error) {
			predictOptions := gRPCPredictOpts(c, loader.ModelPath)
			if len(tokens) > 0 {
				// The predict options carry token ids as int32.
				ids := make([]int32, 0, len(tokens))
				for _, t := range tokens {
					ids = append(ids, int32(t))
				}
				predictOptions.EmbeddingTokens = ids
			} else {
				predictOptions.Embeddings = s
			}

			res, err := m.Embeddings(context.TODO(), predictOptions)
			if err != nil {
				return nil, err
			}
			return res.Embeddings, nil
		}
	// bert embeddings
	case *bert.Bert:
		fn = func() ([]float32, error) {
			if len(tokens) > 0 {
				return m.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
			}
			return m.Embeddings(s, bert.SetThreads(c.Threads))
		}
	default:
		fn = func() ([]float32, error) {
			return nil, fmt.Errorf("embeddings not supported by the backend")
		}
	}

	return func() ([]float32, error) {
		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
		l := Lock(modelFile)
		defer l.Unlock()

		embeds, err := fn()
		if err != nil {
			return embeds, err
		}

		// Remove trailing 0s: shrink the length past every zero at the tail.
		end := len(embeds)
		for end > 0 && embeds[end-1] == 0.0 {
			end--
		}
		return embeds[:end], nil
	}, nil
}