package backend

import (
	"fmt"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

// ModelEmbedding loads the model described by c and returns a function that
// computes an embedding vector when called.
//
// If tokens is non-empty the tokens are sent to the backend as the input;
// otherwise the string s is used. The model is loaded eagerly (before this
// function returns), while the embedding request itself is deferred until the
// returned function is invoked.
//
// An error is returned immediately when embeddings are disabled in c or when
// the model cannot be loaded. The returned function reports an error when the
// loaded backend is not a gRPC client or when the backend call fails.
// Trailing zero values are stripped from the resulting vector.
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
	if !c.Embeddings {
		return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
	}

	modelFile := c.Model
	grpcOpts := gRPCModelOpts(c)

	var inferenceModel interface{}
	var err error

	opts := modelOpts(c, o, []model.Option{
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
		model.WithThreads(uint32(c.Threads)),
		model.WithAssetDir(o.AssetsDestination),
		model.WithModel(modelFile),
		model.WithContext(o.Context),
	})

	// Without an explicit backend, let the loader try to find one that works.
	if c.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		opts = append(opts, model.WithBackendString(c.Backend))
		inferenceModel, err = loader.BackendLoader(opts...)
	}
	if err != nil {
		return nil, err
	}

	var fn func() ([]float32, error)
	// The bound variable is named client (not model) so that it does not
	// shadow the imported model package inside the switch.
	switch client := inferenceModel.(type) {
	case *grpc.Client:
		fn = func() ([]float32, error) {
			predictOptions := gRPCPredictOpts(c, loader.ModelPath)

			// Tokens take precedence over the raw string when both are given.
			if len(tokens) > 0 {
				embeds := make([]int32, 0, len(tokens))
				for _, t := range tokens {
					embeds = append(embeds, int32(t))
				}
				predictOptions.EmbeddingTokens = embeds
			} else {
				predictOptions.Embeddings = s
			}

			res, err := client.Embeddings(o.Context, predictOptions)
			if err != nil {
				return nil, err
			}
			return res.Embeddings, nil
		}
	default:
		fn = func() ([]float32, error) {
			return nil, fmt.Errorf("embeddings not supported by the backend")
		}
	}

	return func() ([]float32, error) {
		embeds, err := fn()
		if err != nil {
			return nil, err
		}

		// Remove trailing 0s.
		// NOTE(review): presumably these are padding emitted by the backend;
		// confirm that a genuine trailing zero component can never occur.
		end := len(embeds)
		for end > 0 && embeds[end-1] == 0.0 {
			end--
		}
		return embeds[:end], nil
	}, nil
}