package backend

import (
	"fmt"

	"github.com/go-skynet/LocalAI/core/config"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

// ModelEmbedding loads the configured backend and returns a closure that
// computes embeddings for either the raw string s or the pre-tokenized
// tokens slice. The actual gRPC call is deferred until the returned
// function is invoked.
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
	modelFile := backendConfig.Model

	grpcOpts := gRPCModelOpts(backendConfig)

	var inferenceModel interface{}
	var err error

	opts := modelOpts(backendConfig, appConfig, []model.Option{
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
		model.WithThreads(uint32(*backendConfig.Threads)),
		model.WithAssetDir(appConfig.AssetsDestination),
		model.WithModel(modelFile),
		model.WithContext(appConfig.Context),
	})

	// If no backend is configured, let the loader pick one greedily;
	// otherwise load the explicitly requested backend.
	if backendConfig.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		opts = append(opts, model.WithBackendString(backendConfig.Backend))
		inferenceModel, err = loader.BackendLoader(opts...)
	}
	if err != nil {
		return nil, err
	}

	var fn func() ([]float32, error)
	switch model := inferenceModel.(type) {
	case grpc.Backend:
		fn = func() ([]float32, error) {
			predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
			// Pre-tokenized input takes precedence over the raw string.
			if len(tokens) > 0 {
				embeds := []int32{}
				for _, t := range tokens {
					embeds = append(embeds, int32(t))
				}
				predictOptions.EmbeddingTokens = embeds

				res, err := model.Embeddings(appConfig.Context, predictOptions)
				if err != nil {
					return nil, err
				}

				return res.Embeddings, nil
			}

			predictOptions.Embeddings = s

			res, err := model.Embeddings(appConfig.Context, predictOptions)
			if err != nil {
				return nil, err
			}

			return res.Embeddings, nil
		}
	default:
		fn = func() ([]float32, error) {
			return nil, fmt.Errorf("embeddings not supported by the backend")
		}
	}

	return func() ([]float32, error) {
		embeds, err := fn()
		if err != nil {
			return embeds, err
		}
		// Remove trailing 0s
		for i := len(embeds) - 1; i >= 0; i-- {
			if embeds[i] == 0.0 {
				embeds = embeds[:i]
			} else {
				break
			}
		}
		return embeds, nil
	}, nil
}
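
// exampleModelEmbeddingUsage is an illustrative sketch (not part of the
// original file) showing how a caller might use ModelEmbedding. Only the
// signature and the two-step call pattern (build the closure, then invoke
// it) are taken from the code above; the loader, backendConfig and
// appConfig values are assumed to come from the caller's existing setup,
// and the function name itself is hypothetical.
func exampleModelEmbeddingUsage(loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) ([]float32, error) {
	// Embed a raw string; pass a non-empty tokens slice instead of s to
	// embed pre-tokenized input.
	embedFn, err := ModelEmbedding("hello world", nil, loader, backendConfig, appConfig)
	if err != nil {
		return nil, err
	}
	// The gRPC embedding request only happens when the closure is called.
	return embedFn()
}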