mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat: add external grpc and model autoloading
This commit is contained in:
parent
1d2ae46ddc
commit
94916749c5
@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
||||
model.WithContext(o.Context),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
|
@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||
}
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithContext(o.Context),
|
||||
model.WithModelFile(c.ImageGenerationAssets),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -1,14 +1,17 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||
@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
|
||||
model.WithContext(o.Context),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
if c.Backend != "" {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
}
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
inferenceModel, err = loader.BackendLoader(opts...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
42
api/backend/transcript.go
Normal file
42
api/backend/transcript.go
Normal file
@ -0,0 +1,42 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModelFile(c.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return nil, fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
Threads: uint32(c.Threads),
|
||||
})
|
||||
}
|
72
api/backend/tts.go
Normal file
72
api/backend/tts.go
Normal file
@ -0,0 +1,72 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
counter := 1
|
||||
fileName := baseName + ext
|
||||
|
||||
for {
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
_, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return fileName
|
||||
}
|
||||
|
||||
counter++
|
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||
}
|
||||
}
|
||||
|
||||
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
|
||||
opts := []model.Option{
|
||||
model.WithBackendString(model.PiperBackend),
|
||||
model.WithModelFile(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
}
|
||||
|
||||
for k, v := range o.ExternalGRPCBackends {
|
||||
opts = append(opts, model.WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
piperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if piperModel == nil {
|
||||
return "", nil, fmt.Errorf("could not load piper model")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||
filePath := filepath.Join(o.AudioDir, fileName)
|
||||
|
||||
modelPath := filepath.Join(o.Loader.ModelPath, modelFile)
|
||||
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Dst: filePath,
|
||||
})
|
||||
|
||||
return filePath, res, err
|
||||
}
|
@ -4,13 +4,15 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/json-iterator/go"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.C:
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||
|
||||
// updates the status with an error
|
||||
@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
displayDownload(fileName, current, total, percentage)
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
var err error
|
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.galleryName != "" {
|
||||
if strings.Contains(op.galleryName, "@") {
|
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
}
|
||||
} else {
|
||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
||||
}
|
||||
@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
}()
|
||||
}
|
||||
|
||||
var lastProgress time.Time = time.Now()
|
||||
var startTime time.Time = time.Now()
|
||||
|
||||
func displayDownload(fileName string, current string, total string, percentage float64) {
|
||||
currentTime := time.Now()
|
||||
|
||||
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||
|
||||
lastProgress = currentTime
|
||||
|
||||
// calculate ETA based on percentage and elapsed time
|
||||
var eta time.Duration
|
||||
if percentage > 0 {
|
||||
elapsed := currentTime.Sub(startTime)
|
||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||
}
|
||||
|
||||
if total != "" {
|
||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||
} else {
|
||||
log.Debug().Msgf("Downloading: %s", current)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel
|
||||
ID string `json:"id"`
|
||||
@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
|
||||
}
|
||||
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
|
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,17 +1,10 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@ -20,22 +13,6 @@ type TTSRequest struct {
|
||||
Input string `json:"input" yaml:"input"`
|
||||
}
|
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
counter := 1
|
||||
fileName := baseName + ext
|
||||
|
||||
for {
|
||||
filePath := filepath.Join(dir, fileName)
|
||||
_, err := os.Stat(filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return fileName
|
||||
}
|
||||
|
||||
counter++
|
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
||||
}
|
||||
}
|
||||
|
||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
piperModel, err := o.Loader.BackendLoader(
|
||||
model.WithBackendString(model.PiperBackend),
|
||||
model.WithModelFile(input.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithAssetDir(o.AssetsDestination))
|
||||
filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if piperModel == nil {
|
||||
return fmt.Errorf("could not load piper model")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed creating audio directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||
filePath := filepath.Join(o.AudioDir, fileName)
|
||||
|
||||
modelPath := filepath.Join(o.Loader.ModelPath, input.Model)
|
||||
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
Text: input.Input,
|
||||
Model: modelPath,
|
||||
Dst: filePath,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -9,10 +8,9 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModelFile(config.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(config.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: dst,
|
||||
Language: input.Language,
|
||||
Threads: uint32(config.Threads),
|
||||
})
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -28,6 +28,10 @@ type Option struct {
|
||||
|
||||
BackendAssets embed.FS
|
||||
AssetsDestination string
|
||||
|
||||
ExternalGRPCBackends map[string]string
|
||||
|
||||
AutoloadGalleries bool
|
||||
}
|
||||
|
||||
type AppOption func(*Option)
|
||||
@ -53,6 +57,19 @@ func WithCors(b bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
var EnableGalleriesAutoload = func(o *Option) {
|
||||
o.AutoloadGalleries = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) AppOption {
|
||||
return func(o *Option) {
|
||||
if o.ExternalGRPCBackends == nil {
|
||||
o.ExternalGRPCBackends = make(map[string]string)
|
||||
}
|
||||
o.ExternalGRPCBackends[name] = uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption {
|
||||
return func(o *Option) {
|
||||
o.CORSAllowOrigins = b
|
||||
|
30
main.go
30
main.go
@ -4,6 +4,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
api "github.com/go-skynet/LocalAI/api"
|
||||
@ -40,6 +41,10 @@ func main() {
|
||||
Name: "f16",
|
||||
EnvVars: []string{"F16"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "autoload-galleries",
|
||||
EnvVars: []string{"AUTOLOAD_GALLERIES"},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "debug",
|
||||
EnvVars: []string{"DEBUG"},
|
||||
@ -108,6 +113,11 @@ func main() {
|
||||
EnvVars: []string{"BACKEND_ASSETS_PATH"},
|
||||
Value: "/tmp/localai/backend_data",
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "external-grpc-backends",
|
||||
Usage: "A list of external grpc backends",
|
||||
EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "context-size",
|
||||
Usage: "Default context size of the model",
|
||||
@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
UsageText: `local-ai [options]`,
|
||||
Copyright: "Ettore Di Giacinto",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
app, err := api.App(
|
||||
|
||||
opts := []options.AppOption{
|
||||
options.WithConfigFile(ctx.String("config-file")),
|
||||
options.WithJSONStringPreload(ctx.String("preload-models")),
|
||||
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||
@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
options.WithThreads(ctx.Int("threads")),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")))
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||
}
|
||||
|
||||
externalgRPC := ctx.StringSlice("external-grpc-backends")
|
||||
// split ":" to get backend name and the uri
|
||||
for _, v := range externalgRPC {
|
||||
backend := v[:strings.IndexByte(v, ':')]
|
||||
uri := v[strings.IndexByte(v, ':')+1:]
|
||||
opts = append(opts, options.WithExternalBackend(backend, uri))
|
||||
}
|
||||
|
||||
if ctx.Bool("autoload-galleries") {
|
||||
opts = append(opts, options.EnableGalleriesAutoload)
|
||||
}
|
||||
|
||||
app, err := api.App(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -18,23 +18,15 @@ type Gallery struct {
|
||||
|
||||
// Installs a model from the gallery (galleryname@modelname)
|
||||
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
config, err := GetGalleryConfigFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
installName := model.Name
|
||||
if req.Name != "" {
|
||||
model.Name = req.Name
|
||||
installName = req.Name
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
return err
|
||||
}
|
||||
|
||||
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil {
|
||||
if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model, err := FindGallery(models, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return applyModel(model)
|
||||
}
|
||||
|
||||
func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) {
|
||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
for _, model := range models {
|
||||
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
|
||||
return applyModel(model)
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no gallery found with name %q", name)
|
||||
}
|
||||
|
||||
// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
|
||||
func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||
models, err := AvailableGalleryModels(galleries, basePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
var model *GalleryModel
|
||||
for _, m := range models {
|
||||
if name == m.Name {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
||||
if model == nil {
|
||||
return fmt.Errorf("no model found with name %q", name)
|
||||
}
|
||||
|
||||
return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
|
||||
}
|
||||
|
||||
// List available models
|
||||
|
@ -19,8 +19,6 @@ import (
|
||||
process "github.com/mudler/go-processmanager"
|
||||
)
|
||||
|
||||
const tokenizerSuffix = ".tokenizer.json"
|
||||
|
||||
const (
|
||||
LlamaBackend = "llama"
|
||||
BloomzBackend = "bloomz"
|
||||
@ -45,7 +43,6 @@ const (
|
||||
StableDiffusionBackend = "stablediffusion"
|
||||
PiperBackend = "piper"
|
||||
LCHuggingFaceBackend = "langchain-huggingface"
|
||||
//GGLLMFalconBackend = "falcon"
|
||||
)
|
||||
|
||||
var AutoLoadBackends []string = []string{
|
||||
@ -75,45 +72,28 @@ func (ml *ModelLoader) StopGRPC() {
|
||||
}
|
||||
}
|
||||
|
||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||
// It also loads the model
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
||||
return func(s string) (*grpc.Client, error) {
|
||||
log.Debug().Msgf("Loading GRPC Model", backend, *o)
|
||||
|
||||
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
|
||||
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
|
||||
// Make sure the process is executable
|
||||
if err := os.Chmod(grpcProcess, 0755); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Loading GRPC Process", grpcProcess)
|
||||
port, err := freeport.GetFreePort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverAddress := fmt.Sprintf("localhost:%d", port)
|
||||
|
||||
log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)
|
||||
log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
|
||||
|
||||
grpcControlProcess := process.New(
|
||||
process.WithTemporaryStateDir(),
|
||||
process.WithName(grpcProcess),
|
||||
process.WithArgs("--addr", serverAddress))
|
||||
|
||||
ml.grpcProcesses[o.modelFile] = grpcControlProcess
|
||||
ml.grpcProcesses[id] = grpcControlProcess
|
||||
|
||||
if err := grpcControlProcess.Run(); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
|
||||
// clean up process
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
@ -128,7 +108,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
|
||||
log.Debug().Msgf("Could not tail stderr")
|
||||
}
|
||||
for line := range t.Lines {
|
||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text)
|
||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
@ -137,13 +117,71 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
|
||||
log.Debug().Msgf("Could not tail stdout")
|
||||
}
|
||||
for line := range t.Lines {
|
||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text)
|
||||
log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||
// It also loads the model
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
||||
return func(s string) (*grpc.Client, error) {
|
||||
log.Debug().Msgf("Loading GRPC Model", backend, *o)
|
||||
|
||||
var client *grpc.Client
|
||||
|
||||
getFreeAddress := func() (string, error) {
|
||||
port, err := freeport.GetFreePort()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
return fmt.Sprintf("127.0.0.1:%d", port), nil
|
||||
}
|
||||
|
||||
// Check if the backend is provided as external
|
||||
if uri, ok := o.externalBackends[backend]; ok {
|
||||
log.Debug().Msgf("Loading external backend: %s", uri)
|
||||
// check if uri is a file or a address
|
||||
if _, err := os.Stat(uri); err == nil {
|
||||
serverAddress, err := getFreeAddress()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
// Make sure the process is executable
|
||||
if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("GRPC Service Started")
|
||||
|
||||
client := grpc.NewClient(serverAddress)
|
||||
client = grpc.NewClient(serverAddress)
|
||||
} else {
|
||||
// address
|
||||
client = grpc.NewClient(uri)
|
||||
}
|
||||
} else {
|
||||
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
||||
}
|
||||
|
||||
serverAddress, err := getFreeAddress()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||
}
|
||||
|
||||
// Make sure the process is executable
|
||||
if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("GRPC Service Started")
|
||||
|
||||
client = grpc.NewClient(serverAddress)
|
||||
}
|
||||
|
||||
// Wait for the service to start up
|
||||
ready := false
|
||||
@ -158,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
|
||||
|
||||
if !ready {
|
||||
log.Debug().Msgf("GRPC Service NOT ready")
|
||||
log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive())
|
||||
log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:"))
|
||||
|
||||
log.Debug().Msgf(grpcControlProcess.ExitCode())
|
||||
|
||||
return nil, fmt.Errorf("grpc service not ready")
|
||||
}
|
||||
|
||||
@ -189,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
|
||||
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
|
||||
|
||||
backend := strings.ToLower(o.backendString)
|
||||
|
||||
// if an external backend is provided, use it
|
||||
_, externalBackendExists := o.externalBackends[backend]
|
||||
if externalBackendExists {
|
||||
return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o))
|
||||
}
|
||||
|
||||
switch backend {
|
||||
case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend,
|
||||
MPTBackend, Gpt2Backend, FalconBackend,
|
||||
@ -209,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
|
||||
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
||||
o := NewOptions(opts...)
|
||||
|
||||
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
|
||||
|
||||
// Is this really needed? BackendLoader already does this
|
||||
ml.mu.Lock()
|
||||
if m := ml.checkIsLoaded(o.modelFile); m != nil {
|
||||
@ -221,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
||||
ml.mu.Unlock()
|
||||
var err error
|
||||
|
||||
for _, b := range AutoLoadBackends {
|
||||
log.Debug().Msgf("[%s] Attempting to load", b)
|
||||
// autoload also external backends
|
||||
allBackendsToAutoLoad := []string{}
|
||||
allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...)
|
||||
for _, b := range o.externalBackends {
|
||||
allBackendsToAutoLoad = append(allBackendsToAutoLoad, b)
|
||||
}
|
||||
log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", "))
|
||||
|
||||
model, modelerr := ml.BackendLoader(
|
||||
for _, b := range allBackendsToAutoLoad {
|
||||
log.Debug().Msgf("[%s] Attempting to load", b)
|
||||
options := []Option{
|
||||
WithBackendString(b),
|
||||
WithModelFile(o.modelFile),
|
||||
WithLoadGRPCLLMModelOpts(o.gRPCOptions),
|
||||
WithThreads(o.threads),
|
||||
WithAssetDir(o.assetDir),
|
||||
)
|
||||
}
|
||||
|
||||
for k, v := range o.externalBackends {
|
||||
options = append(options, WithExternalBackend(k, v))
|
||||
}
|
||||
|
||||
model, modelerr := ml.BackendLoader(options...)
|
||||
if modelerr == nil && model != nil {
|
||||
log.Debug().Msgf("[%s] Loads OK", b)
|
||||
return model, nil
|
||||
|
@ -14,10 +14,21 @@ type Options struct {
|
||||
context context.Context
|
||||
|
||||
gRPCOptions *pb.ModelOptions
|
||||
|
||||
externalBackends map[string]string
|
||||
}
|
||||
|
||||
type Option func(*Options)
|
||||
|
||||
func WithExternalBackend(name string, uri string) Option {
|
||||
return func(o *Options) {
|
||||
if o.externalBackends == nil {
|
||||
o.externalBackends = make(map[string]string)
|
||||
}
|
||||
o.externalBackends[name] = uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendString(backend string) Option {
|
||||
return func(o *Options) {
|
||||
o.backendString = backend
|
||||
|
37
pkg/utils/logging.go
Normal file
37
pkg/utils/logging.go
Normal file
@ -0,0 +1,37 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var lastProgress time.Time = time.Now()
|
||||
var startTime time.Time = time.Now()
|
||||
|
||||
func ResetDownloadTimers() {
|
||||
lastProgress = time.Now()
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) {
|
||||
currentTime := time.Now()
|
||||
|
||||
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||
|
||||
lastProgress = currentTime
|
||||
|
||||
// calculate ETA based on percentage and elapsed time
|
||||
var eta time.Duration
|
||||
if percentage > 0 {
|
||||
elapsed := currentTime.Sub(startTime)
|
||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||
}
|
||||
|
||||
if total != "" {
|
||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||
} else {
|
||||
log.Debug().Msgf("Downloading: %s", current)
|
||||
}
|
||||
}
|
||||
}
|
5
tests/models_fixtures/grpc.yaml
Normal file
5
tests/models_fixtures/grpc.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
name: code-search-ada-code-001
|
||||
backend: huggingface
|
||||
embeddings: true
|
||||
parameters:
|
||||
model: all-MiniLM-L6-v2
|
Loading…
Reference in New Issue
Block a user