mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat: add external grpc and model autoloading
This commit is contained in:
parent
1d2ae46ddc
commit
94916749c5
@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
|||||||
model.WithContext(o.Context),
|
model.WithContext(o.Context),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k, v := range o.ExternalGRPCBackends {
|
||||||
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
if c.Backend == "" {
|
if c.Backend == "" {
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||||
} else {
|
} else {
|
||||||
|
@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
|||||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
|
||||||
}
|
}
|
||||||
|
|
||||||
inferenceModel, err := loader.BackendLoader(
|
opts := []model.Option{
|
||||||
model.WithBackendString(c.Backend),
|
model.WithBackendString(c.Backend),
|
||||||
model.WithAssetDir(o.AssetsDestination),
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
model.WithThreads(uint32(c.Threads)),
|
model.WithThreads(uint32(c.Threads)),
|
||||||
model.WithContext(o.Context),
|
model.WithContext(o.Context),
|
||||||
model.WithModelFile(c.ImageGenerationAssets),
|
model.WithModelFile(c.ImageGenerationAssets),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range o.ExternalGRPCBackends {
|
||||||
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
inferenceModel, err := loader.BackendLoader(
|
||||||
|
opts...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
package backend
|
package backend
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
||||||
@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
|
|||||||
model.WithContext(o.Context),
|
model.WithContext(o.Context),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k, v := range o.ExternalGRPCBackends {
|
||||||
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Backend != "" {
|
||||||
|
opts = append(opts, model.WithBackendString(c.Backend))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||||
|
if o.AutoloadGalleries { // experimental
|
||||||
|
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||||
|
utils.ResetDownloadTimers()
|
||||||
|
// if we failed to load the model, we try to download it
|
||||||
|
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.Backend == "" {
|
if c.Backend == "" {
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||||
} else {
|
} else {
|
||||||
opts = append(opts, model.WithBackendString(c.Backend))
|
|
||||||
inferenceModel, err = loader.BackendLoader(opts...)
|
inferenceModel, err = loader.BackendLoader(opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
42
api/backend/transcript.go
Normal file
42
api/backend/transcript.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
|
||||||
|
opts := []model.Option{
|
||||||
|
model.WithBackendString(model.WhisperBackend),
|
||||||
|
model.WithModelFile(c.Model),
|
||||||
|
model.WithContext(o.Context),
|
||||||
|
model.WithThreads(uint32(c.Threads)),
|
||||||
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range o.ExternalGRPCBackends {
|
||||||
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
whisperModel, err := o.Loader.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if whisperModel == nil {
|
||||||
|
return nil, fmt.Errorf("could not load whisper model")
|
||||||
|
}
|
||||||
|
|
||||||
|
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||||
|
Dst: audio,
|
||||||
|
Language: language,
|
||||||
|
Threads: uint32(c.Threads),
|
||||||
|
})
|
||||||
|
}
|
72
api/backend/tts.go
Normal file
72
api/backend/tts.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateUniqueFileName returns a file name (not a full path) built from
// baseName and ext that does not currently exist inside dir.
//
// The first candidate is baseName+ext; if that is taken, the candidates
// baseName_2+ext, baseName_3+ext, ... are tried in order.
//
// A candidate is considered taken only when os.Stat succeeds on it. Any Stat
// failure (not just "does not exist") ends the search and returns the current
// candidate: errors such as ENOTDIR or EACCES apply to every candidate in the
// same directory, so retrying would loop forever. The caller's subsequent
// attempt to create the file is expected to surface the underlying problem.
//
// Note that the check is not atomic with file creation, so two concurrent
// callers may be handed the same name.
func generateUniqueFileName(dir, baseName, ext string) string {
	fileName := baseName + ext

	for counter := 2; ; counter++ {
		if _, err := os.Stat(filepath.Join(dir, fileName)); err != nil {
			return fileName
		}

		fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
	}
}
|
||||||
|
|
||||||
|
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
|
||||||
|
opts := []model.Option{
|
||||||
|
model.WithBackendString(model.PiperBackend),
|
||||||
|
model.WithModelFile(modelFile),
|
||||||
|
model.WithContext(o.Context),
|
||||||
|
model.WithAssetDir(o.AssetsDestination),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range o.ExternalGRPCBackends {
|
||||||
|
opts = append(opts, model.WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
piperModel, err := o.Loader.BackendLoader(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if piperModel == nil {
|
||||||
|
return "", nil, fmt.Errorf("could not load piper model")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
||||||
|
filePath := filepath.Join(o.AudioDir, fileName)
|
||||||
|
|
||||||
|
modelPath := filepath.Join(o.Loader.ModelPath, modelFile)
|
||||||
|
|
||||||
|
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
||||||
|
Text: text,
|
||||||
|
Model: modelPath,
|
||||||
|
Dst: filePath,
|
||||||
|
})
|
||||||
|
|
||||||
|
return filePath, res, err
|
||||||
|
}
|
@ -4,13 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
json "github.com/json-iterator/go"
|
json "github.com/json-iterator/go"
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
|||||||
case <-c.Done():
|
case <-c.Done():
|
||||||
return
|
return
|
||||||
case op := <-g.C:
|
case op := <-g.C:
|
||||||
|
utils.ResetDownloadTimers()
|
||||||
|
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||||
|
|
||||||
// updates the status with an error
|
// updates the status with an error
|
||||||
@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
|||||||
// displayDownload displays the download progress
|
// displayDownload displays the download progress
|
||||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||||
displayDownload(fileName, current, total, percentage)
|
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||||
if op.galleryName != "" {
|
if op.galleryName != "" {
|
||||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
if strings.Contains(op.galleryName, "@") {
|
||||||
|
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||||
|
} else {
|
||||||
|
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
||||||
}
|
}
|
||||||
@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastProgress time.Time = time.Now()
|
|
||||||
var startTime time.Time = time.Now()
|
|
||||||
|
|
||||||
func displayDownload(fileName string, current string, total string, percentage float64) {
|
|
||||||
currentTime := time.Now()
|
|
||||||
|
|
||||||
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
|
||||||
|
|
||||||
lastProgress = currentTime
|
|
||||||
|
|
||||||
// calculate ETA based on percentage and elapsed time
|
|
||||||
var eta time.Duration
|
|
||||||
if percentage > 0 {
|
|
||||||
elapsed := currentTime.Sub(startTime)
|
|
||||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
|
||||||
}
|
|
||||||
|
|
||||||
if total != "" {
|
|
||||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
|
||||||
} else {
|
|
||||||
log.Debug().Msgf("Downloading: %s", current)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type galleryModel struct {
|
type galleryModel struct {
|
||||||
gallery.GalleryModel
|
gallery.GalleryModel
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range requests {
|
for _, r := range requests {
|
||||||
|
utils.ResetDownloadTimers()
|
||||||
if r.ID == "" {
|
if r.ID == "" {
|
||||||
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload)
|
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||||
} else {
|
} else {
|
||||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload)
|
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,17 +1,10 @@
|
|||||||
package localai
|
package localai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"github.com/go-skynet/LocalAI/api/backend"
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,22 +13,6 @@ type TTSRequest struct {
|
|||||||
Input string `json:"input" yaml:"input"`
|
Input string `json:"input" yaml:"input"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateUniqueFileName(dir, baseName, ext string) string {
|
|
||||||
counter := 1
|
|
||||||
fileName := baseName + ext
|
|
||||||
|
|
||||||
for {
|
|
||||||
filePath := filepath.Join(dir, fileName)
|
|
||||||
_, err := os.Stat(filePath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return fileName
|
|
||||||
}
|
|
||||||
|
|
||||||
counter++
|
|
||||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
|
||||||
@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
piperModel, err := o.Loader.BackendLoader(
|
filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o)
|
||||||
model.WithBackendString(model.PiperBackend),
|
|
||||||
model.WithModelFile(input.Model),
|
|
||||||
model.WithContext(o.Context),
|
|
||||||
model.WithAssetDir(o.AssetsDestination))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if piperModel == nil {
|
|
||||||
return fmt.Errorf("could not load piper model")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed creating audio directory: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
|
|
||||||
filePath := filepath.Join(o.AudioDir, fileName)
|
|
||||||
|
|
||||||
modelPath := filepath.Join(o.Loader.ModelPath, input.Model)
|
|
||||||
|
|
||||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
|
|
||||||
Text: input.Input,
|
|
||||||
Model: modelPath,
|
|
||||||
Dst: filePath,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Download(filePath)
|
return c.Download(filePath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -9,10 +8,9 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend"
|
||||||
config "github.com/go-skynet/LocalAI/api/config"
|
config "github.com/go-skynet/LocalAI/api/config"
|
||||||
"github.com/go-skynet/LocalAI/api/options"
|
"github.com/go-skynet/LocalAI/api/options"
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
whisperModel, err := o.Loader.BackendLoader(
|
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
||||||
model.WithBackendString(model.WhisperBackend),
|
|
||||||
model.WithModelFile(config.Model),
|
|
||||||
model.WithContext(o.Context),
|
|
||||||
model.WithThreads(uint32(config.Threads)),
|
|
||||||
model.WithAssetDir(o.AssetsDestination))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if whisperModel == nil {
|
|
||||||
return fmt.Errorf("could not load whisper model")
|
|
||||||
}
|
|
||||||
|
|
||||||
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
|
||||||
Dst: dst,
|
|
||||||
Language: input.Language,
|
|
||||||
Threads: uint32(config.Threads),
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -28,6 +28,10 @@ type Option struct {
|
|||||||
|
|
||||||
BackendAssets embed.FS
|
BackendAssets embed.FS
|
||||||
AssetsDestination string
|
AssetsDestination string
|
||||||
|
|
||||||
|
ExternalGRPCBackends map[string]string
|
||||||
|
|
||||||
|
AutoloadGalleries bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppOption func(*Option)
|
type AppOption func(*Option)
|
||||||
@ -53,6 +57,19 @@ func WithCors(b bool) AppOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var EnableGalleriesAutoload = func(o *Option) {
|
||||||
|
o.AutoloadGalleries = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithExternalBackend(name string, uri string) AppOption {
|
||||||
|
return func(o *Option) {
|
||||||
|
if o.ExternalGRPCBackends == nil {
|
||||||
|
o.ExternalGRPCBackends = make(map[string]string)
|
||||||
|
}
|
||||||
|
o.ExternalGRPCBackends[name] = uri
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithCorsAllowOrigins(b string) AppOption {
|
func WithCorsAllowOrigins(b string) AppOption {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.CORSAllowOrigins = b
|
o.CORSAllowOrigins = b
|
||||||
|
30
main.go
30
main.go
@ -4,6 +4,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
api "github.com/go-skynet/LocalAI/api"
|
api "github.com/go-skynet/LocalAI/api"
|
||||||
@ -40,6 +41,10 @@ func main() {
|
|||||||
Name: "f16",
|
Name: "f16",
|
||||||
EnvVars: []string{"F16"},
|
EnvVars: []string{"F16"},
|
||||||
},
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "autoload-galleries",
|
||||||
|
EnvVars: []string{"AUTOLOAD_GALLERIES"},
|
||||||
|
},
|
||||||
&cli.BoolFlag{
|
&cli.BoolFlag{
|
||||||
Name: "debug",
|
Name: "debug",
|
||||||
EnvVars: []string{"DEBUG"},
|
EnvVars: []string{"DEBUG"},
|
||||||
@ -108,6 +113,11 @@ func main() {
|
|||||||
EnvVars: []string{"BACKEND_ASSETS_PATH"},
|
EnvVars: []string{"BACKEND_ASSETS_PATH"},
|
||||||
Value: "/tmp/localai/backend_data",
|
Value: "/tmp/localai/backend_data",
|
||||||
},
|
},
|
||||||
|
&cli.StringSliceFlag{
|
||||||
|
Name: "external-grpc-backends",
|
||||||
|
Usage: "A list of external grpc backends",
|
||||||
|
EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"},
|
||||||
|
},
|
||||||
&cli.IntFlag{
|
&cli.IntFlag{
|
||||||
Name: "context-size",
|
Name: "context-size",
|
||||||
Usage: "Default context size of the model",
|
Usage: "Default context size of the model",
|
||||||
@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
|||||||
UsageText: `local-ai [options]`,
|
UsageText: `local-ai [options]`,
|
||||||
Copyright: "Ettore Di Giacinto",
|
Copyright: "Ettore Di Giacinto",
|
||||||
Action: func(ctx *cli.Context) error {
|
Action: func(ctx *cli.Context) error {
|
||||||
app, err := api.App(
|
|
||||||
|
opts := []options.AppOption{
|
||||||
options.WithConfigFile(ctx.String("config-file")),
|
options.WithConfigFile(ctx.String("config-file")),
|
||||||
options.WithJSONStringPreload(ctx.String("preload-models")),
|
options.WithJSONStringPreload(ctx.String("preload-models")),
|
||||||
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||||
@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
|||||||
options.WithThreads(ctx.Int("threads")),
|
options.WithThreads(ctx.Int("threads")),
|
||||||
options.WithBackendAssets(backendAssets),
|
options.WithBackendAssets(backendAssets),
|
||||||
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||||
options.WithUploadLimitMB(ctx.Int("upload-limit")))
|
options.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||||
|
}
|
||||||
|
|
||||||
|
externalgRPC := ctx.StringSlice("external-grpc-backends")
|
||||||
|
// split ":" to get backend name and the uri
|
||||||
|
for _, v := range externalgRPC {
|
||||||
|
backend := v[:strings.IndexByte(v, ':')]
|
||||||
|
uri := v[strings.IndexByte(v, ':')+1:]
|
||||||
|
opts = append(opts, options.WithExternalBackend(backend, uri))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Bool("autoload-galleries") {
|
||||||
|
opts = append(opts, options.EnableGalleriesAutoload)
|
||||||
|
}
|
||||||
|
|
||||||
|
app, err := api.App(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -18,23 +18,15 @@ type Gallery struct {
|
|||||||
|
|
||||||
// Installs a model from the gallery (galleryname@modelname)
|
// Installs a model from the gallery (galleryname@modelname)
|
||||||
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||||
|
|
||||||
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
|
||||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
|
||||||
|
|
||||||
models, err := AvailableGalleryModels(galleries, basePath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
applyModel := func(model *GalleryModel) error {
|
applyModel := func(model *GalleryModel) error {
|
||||||
config, err := GetGalleryConfigFromURL(model.URL)
|
config, err := GetGalleryConfigFromURL(model.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
installName := model.Name
|
||||||
if req.Name != "" {
|
if req.Name != "" {
|
||||||
model.Name = req.Name
|
installName = req.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||||
@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil {
|
if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
models, err := AvailableGalleryModels(galleries, basePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := FindGallery(models, name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return applyModel(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) {
|
||||||
|
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
|
||||||
|
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
|
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
|
||||||
return applyModel(model)
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no gallery found with name %q", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
|
||||||
|
func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
|
||||||
|
models, err := AvailableGalleryModels(galleries, basePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||||
|
var model *GalleryModel
|
||||||
|
for _, m := range models {
|
||||||
|
if name == m.Name {
|
||||||
|
model = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("no model found with name %q", name)
|
if model == nil {
|
||||||
|
return fmt.Errorf("no model found with name %q", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List available models
|
// List available models
|
||||||
|
@ -19,8 +19,6 @@ import (
|
|||||||
process "github.com/mudler/go-processmanager"
|
process "github.com/mudler/go-processmanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tokenizerSuffix = ".tokenizer.json"
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
LlamaBackend = "llama"
|
LlamaBackend = "llama"
|
||||||
BloomzBackend = "bloomz"
|
BloomzBackend = "bloomz"
|
||||||
@ -45,7 +43,6 @@ const (
|
|||||||
StableDiffusionBackend = "stablediffusion"
|
StableDiffusionBackend = "stablediffusion"
|
||||||
PiperBackend = "piper"
|
PiperBackend = "piper"
|
||||||
LCHuggingFaceBackend = "langchain-huggingface"
|
LCHuggingFaceBackend = "langchain-huggingface"
|
||||||
//GGLLMFalconBackend = "falcon"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var AutoLoadBackends []string = []string{
|
var AutoLoadBackends []string = []string{
|
||||||
@ -75,75 +72,116 @@ func (ml *ModelLoader) StopGRPC() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
|
||||||
|
// Make sure the process is executable
|
||||||
|
if err := os.Chmod(grpcProcess, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("Loading GRPC Process", grpcProcess)
|
||||||
|
|
||||||
|
log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
|
||||||
|
|
||||||
|
grpcControlProcess := process.New(
|
||||||
|
process.WithTemporaryStateDir(),
|
||||||
|
process.WithName(grpcProcess),
|
||||||
|
process.WithArgs("--addr", serverAddress))
|
||||||
|
|
||||||
|
ml.grpcProcesses[id] = grpcControlProcess
|
||||||
|
|
||||||
|
if err := grpcControlProcess.Run(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
|
||||||
|
// clean up process
|
||||||
|
go func() {
|
||||||
|
c := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||||
|
<-c
|
||||||
|
grpcControlProcess.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Msgf("Could not tail stderr")
|
||||||
|
}
|
||||||
|
for line := range t.Lines {
|
||||||
|
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Msgf("Could not tail stdout")
|
||||||
|
}
|
||||||
|
for line := range t.Lines {
|
||||||
|
log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||||
// It also loads the model
|
// It also loads the model
|
||||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
|
||||||
return func(s string) (*grpc.Client, error) {
|
return func(s string) (*grpc.Client, error) {
|
||||||
log.Debug().Msgf("Loading GRPC Model", backend, *o)
|
log.Debug().Msgf("Loading GRPC Model", backend, *o)
|
||||||
|
|
||||||
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
|
var client *grpc.Client
|
||||||
|
|
||||||
// Check if the file exists
|
getFreeAddress := func() (string, error) {
|
||||||
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
port, err := freeport.GetFreePort()
|
||||||
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the process is executable
|
|
||||||
if err := os.Chmod(grpcProcess, 0755); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msgf("Loading GRPC Process", grpcProcess)
|
|
||||||
port, err := freeport.GetFreePort()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddress := fmt.Sprintf("localhost:%d", port)
|
|
||||||
|
|
||||||
log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)
|
|
||||||
|
|
||||||
grpcControlProcess := process.New(
|
|
||||||
process.WithTemporaryStateDir(),
|
|
||||||
process.WithName(grpcProcess),
|
|
||||||
process.WithArgs("--addr", serverAddress))
|
|
||||||
|
|
||||||
ml.grpcProcesses[o.modelFile] = grpcControlProcess
|
|
||||||
|
|
||||||
if err := grpcControlProcess.Run(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// clean up process
|
|
||||||
go func() {
|
|
||||||
c := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
|
||||||
<-c
|
|
||||||
grpcControlProcess.Stop()
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Msgf("Could not tail stderr")
|
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||||
}
|
}
|
||||||
for line := range t.Lines {
|
return fmt.Sprintf("127.0.0.1:%d", port), nil
|
||||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text)
|
}
|
||||||
|
|
||||||
|
// Check if the backend is provided as external
|
||||||
|
if uri, ok := o.externalBackends[backend]; ok {
|
||||||
|
log.Debug().Msgf("Loading external backend: %s", uri)
|
||||||
|
// check if uri is a file or a address
|
||||||
|
if _, err := os.Stat(uri); err == nil {
|
||||||
|
serverAddress, err := getFreeAddress()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||||
|
}
|
||||||
|
// Make sure the process is executable
|
||||||
|
if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msgf("GRPC Service Started")
|
||||||
|
|
||||||
|
client = grpc.NewClient(serverAddress)
|
||||||
|
} else {
|
||||||
|
// address
|
||||||
|
client = grpc.NewClient(uri)
|
||||||
}
|
}
|
||||||
}()
|
} else {
|
||||||
go func() {
|
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
|
||||||
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
|
// Check if the file exists
|
||||||
|
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddress, err := getFreeAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Msgf("Could not tail stdout")
|
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
|
||||||
}
|
}
|
||||||
for line := range t.Lines {
|
|
||||||
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text)
|
// Make sure the process is executable
|
||||||
|
if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
log.Debug().Msgf("GRPC Service Started")
|
log.Debug().Msgf("GRPC Service Started")
|
||||||
|
|
||||||
client := grpc.NewClient(serverAddress)
|
client = grpc.NewClient(serverAddress)
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for the service to start up
|
// Wait for the service to start up
|
||||||
ready := false
|
ready := false
|
||||||
@ -158,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
|
|||||||
|
|
||||||
if !ready {
|
if !ready {
|
||||||
log.Debug().Msgf("GRPC Service NOT ready")
|
log.Debug().Msgf("GRPC Service NOT ready")
|
||||||
log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive())
|
|
||||||
log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:"))
|
|
||||||
|
|
||||||
log.Debug().Msgf(grpcControlProcess.ExitCode())
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("grpc service not ready")
|
return nil, fmt.Errorf("grpc service not ready")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
|
|||||||
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
|
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
|
||||||
|
|
||||||
backend := strings.ToLower(o.backendString)
|
backend := strings.ToLower(o.backendString)
|
||||||
|
|
||||||
|
// if an external backend is provided, use it
|
||||||
|
_, externalBackendExists := o.externalBackends[backend]
|
||||||
|
if externalBackendExists {
|
||||||
|
return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o))
|
||||||
|
}
|
||||||
|
|
||||||
switch backend {
|
switch backend {
|
||||||
case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend,
|
case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend,
|
||||||
MPTBackend, Gpt2Backend, FalconBackend,
|
MPTBackend, Gpt2Backend, FalconBackend,
|
||||||
@ -209,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
|
|||||||
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
||||||
o := NewOptions(opts...)
|
o := NewOptions(opts...)
|
||||||
|
|
||||||
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
|
|
||||||
|
|
||||||
// Is this really needed? BackendLoader already does this
|
// Is this really needed? BackendLoader already does this
|
||||||
ml.mu.Lock()
|
ml.mu.Lock()
|
||||||
if m := ml.checkIsLoaded(o.modelFile); m != nil {
|
if m := ml.checkIsLoaded(o.modelFile); m != nil {
|
||||||
@ -221,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
|
|||||||
ml.mu.Unlock()
|
ml.mu.Unlock()
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
for _, b := range AutoLoadBackends {
|
// autoload also external backends
|
||||||
log.Debug().Msgf("[%s] Attempting to load", b)
|
allBackendsToAutoLoad := []string{}
|
||||||
|
allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...)
|
||||||
|
for _, b := range o.externalBackends {
|
||||||
|
allBackendsToAutoLoad = append(allBackendsToAutoLoad, b)
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", "))
|
||||||
|
|
||||||
model, modelerr := ml.BackendLoader(
|
for _, b := range allBackendsToAutoLoad {
|
||||||
|
log.Debug().Msgf("[%s] Attempting to load", b)
|
||||||
|
options := []Option{
|
||||||
WithBackendString(b),
|
WithBackendString(b),
|
||||||
WithModelFile(o.modelFile),
|
WithModelFile(o.modelFile),
|
||||||
WithLoadGRPCLLMModelOpts(o.gRPCOptions),
|
WithLoadGRPCLLMModelOpts(o.gRPCOptions),
|
||||||
WithThreads(o.threads),
|
WithThreads(o.threads),
|
||||||
WithAssetDir(o.assetDir),
|
WithAssetDir(o.assetDir),
|
||||||
)
|
}
|
||||||
|
|
||||||
|
for k, v := range o.externalBackends {
|
||||||
|
options = append(options, WithExternalBackend(k, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
model, modelerr := ml.BackendLoader(options...)
|
||||||
if modelerr == nil && model != nil {
|
if modelerr == nil && model != nil {
|
||||||
log.Debug().Msgf("[%s] Loads OK", b)
|
log.Debug().Msgf("[%s] Loads OK", b)
|
||||||
return model, nil
|
return model, nil
|
||||||
|
@ -14,10 +14,21 @@ type Options struct {
|
|||||||
context context.Context
|
context context.Context
|
||||||
|
|
||||||
gRPCOptions *pb.ModelOptions
|
gRPCOptions *pb.ModelOptions
|
||||||
|
|
||||||
|
externalBackends map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(*Options)
|
type Option func(*Options)
|
||||||
|
|
||||||
|
func WithExternalBackend(name string, uri string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
if o.externalBackends == nil {
|
||||||
|
o.externalBackends = make(map[string]string)
|
||||||
|
}
|
||||||
|
o.externalBackends[name] = uri
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithBackendString(backend string) Option {
|
func WithBackendString(backend string) Option {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
o.backendString = backend
|
o.backendString = backend
|
||||||
|
37
pkg/utils/logging.go
Normal file
37
pkg/utils/logging.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var lastProgress time.Time = time.Now()
|
||||||
|
var startTime time.Time = time.Now()
|
||||||
|
|
||||||
|
func ResetDownloadTimers() {
|
||||||
|
lastProgress = time.Now()
|
||||||
|
startTime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) {
|
||||||
|
currentTime := time.Now()
|
||||||
|
|
||||||
|
if currentTime.Sub(lastProgress) >= 5*time.Second {
|
||||||
|
|
||||||
|
lastProgress = currentTime
|
||||||
|
|
||||||
|
// calculate ETA based on percentage and elapsed time
|
||||||
|
var eta time.Duration
|
||||||
|
if percentage > 0 {
|
||||||
|
elapsed := currentTime.Sub(startTime)
|
||||||
|
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
|
||||||
|
}
|
||||||
|
|
||||||
|
if total != "" {
|
||||||
|
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msgf("Downloading: %s", current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
5
tests/models_fixtures/grpc.yaml
Normal file
5
tests/models_fixtures/grpc.yaml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
name: code-search-ada-code-001
|
||||||
|
backend: huggingface
|
||||||
|
embeddings: true
|
||||||
|
parameters:
|
||||||
|
model: all-MiniLM-L6-v2
|
Loading…
Reference in New Issue
Block a user