refactor: move remaining api packages to core (#1731)

* core 1

* api/openai/files fix

* core 2 - core/config

* move over core api.go and tests to the start of core/http

* move over localai specific endpoints to core/http, begin the service/endpoint split there

* refactor big chunk on the plane

* refactor chunk 2 on plane, next step: port and modify changes to request.go

* easy fixes for request.go, major changes not done yet

* lintfix

* json tag lintfix?

* gitignore and .keep files

* strange fix attempt: rename the config dir?
This commit is contained in:
Dave 2024-03-01 10:19:53 -05:00 committed by GitHub
parent 316de82f51
commit 1c312685aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 1440 additions and 1206 deletions

4
.gitignore vendored
View File

@ -21,6 +21,7 @@ local-ai
!charts/* !charts/*
# prevent above rules from omitting the api/localai folder # prevent above rules from omitting the api/localai folder
!api/localai !api/localai
!core/**/localai
# Ignore models # Ignore models
models/* models/*
@ -34,6 +35,7 @@ release/
.idea .idea
# Generated during build # Generated during build
backend-assets/ backend-assets/*
!backend-assets/.keep
prepare prepare
/ggml-metal.metal /ggml-metal.metal

View File

@ -44,6 +44,8 @@ BUILD_ID?=git
TEST_DIR=/tmp/test TEST_DIR=/tmp/test
TEST_FLAKES?=5
RANDOM := $(shell bash -c 'echo $$RANDOM') RANDOM := $(shell bash -c 'echo $$RANDOM')
VERSION?=$(shell git describe --always --tags || echo "dev" ) VERSION?=$(shell git describe --always --tags || echo "dev" )
@ -337,7 +339,7 @@ test: prepare test-models/testmodel grpcs
export GO_TAGS="tts stablediffusion" export GO_TAGS="tts stablediffusion"
$(MAKE) prepare-test $(MAKE) prepare-test
HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ HUGGINGFACE_GRPC=$(abspath ./)/backend/python/sentencetransformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts 5 --fail-fast -v -r $(TEST_PATHS) $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama && !llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS)
$(MAKE) test-gpt4all $(MAKE) test-gpt4all
$(MAKE) test-llama $(MAKE) test-llama
$(MAKE) test-llama-gguf $(MAKE) test-llama-gguf

View File

@ -1,162 +0,0 @@
package localai
import (
"context"
"fmt"
"strings"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type BackendMonitor struct {
configLoader *config.ConfigLoader
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
options: options,
}
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.options.Loader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
input := new(BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return "", err
}
config, exists := bm.configLoader.GetConfig(input.Model)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = input.Model
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
model := bm.options.Loader.CheckIsLoaded(backendId)
if model == "" {
return fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return c.JSON(proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
})
}
return c.JSON(status)
}
}
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendId, err := bm.getModelLoaderIDFromCtx(c)
if err != nil {
return err
}
return bm.options.Loader.ShutdownModel(backendId)
}
}

View File

@ -1,326 +0,0 @@
package localai
import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
json "github.com/json-iterator/go"
"gopkg.in/yaml.v3"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type galleryOp struct {
req gallery.GalleryModel
id string
galleries []gallery.Gallery
galleryName string
}
type galleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func NewGalleryService(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" {
if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cm.LoadConfigs(g.modelPath)
if err != nil {
updateError(err)
continue
}
err = cm.Preload(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cm, galleries, requests)
}
/// Endpoint Service
type ModelGalleryService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *galleryApplier
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
return ModelGalleryService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.getAllStatus())
}
}
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- galleryOp{
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

0
configuration/.keep Normal file
View File

View File

@ -3,36 +3,36 @@ package backend
import ( import (
"fmt" "fmt"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
) )
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
if !c.Embeddings { if !backendConfig.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
} }
modelFile := c.Model modelFile := backendConfig.Model
grpcOpts := gRPCModelOpts(c) grpcOpts := gRPCModelOpts(backendConfig)
var inferenceModel interface{} var inferenceModel interface{}
var err error var err error
opts := modelOpts(c, o, []model.Option{ opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(backendConfig.Threads)),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile), model.WithModel(modelFile),
model.WithContext(o.Context), model.WithContext(appConfig.Context),
}) })
if c.Backend == "" { if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err = loader.GreedyLoader(opts...)
} else { } else {
opts = append(opts, model.WithBackendString(c.Backend)) opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...) inferenceModel, err = loader.BackendLoader(opts...)
} }
if err != nil { if err != nil {
@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
switch model := inferenceModel.(type) { switch model := inferenceModel.(type) {
case grpc.Backend: case grpc.Backend:
fn = func() ([]float32, error) { fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath) predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
if len(tokens) > 0 { if len(tokens) > 0 {
embeds := []int32{} embeds := []int32{}
@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
} }
predictOptions.EmbeddingTokens = embeds predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(o.Context, predictOptions) res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
} }
predictOptions.Embeddings = s predictOptions.Embeddings = s
res, err := model.Embeddings(o.Context, predictOptions) res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,33 +1,33 @@
package backend package backend
import ( import (
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
) )
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
opts := modelOpts(c, o, []model.Option{ opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(c.Backend), model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(backendConfig.Threads)),
model.WithContext(o.Context), model.WithContext(appConfig.Context),
model.WithModel(c.Model), model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA, CUDA: backendConfig.CUDA || backendConfig.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType, SchedulerType: backendConfig.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType, PipelineType: backendConfig.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale, CFGScale: backendConfig.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter, LoraAdapter: backendConfig.LoraAdapter,
LoraScale: c.LoraScale, LoraScale: backendConfig.LoraScale,
LoraBase: c.LoraBase, LoraBase: backendConfig.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG, IMG2IMG: backendConfig.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel, CLIPModel: backendConfig.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder, CLIPSubfolder: backendConfig.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip), CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet, ControlNet: backendConfig.Diffusers.ControlNet,
}), }),
}) })
@ -40,19 +40,19 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
fn := func() error { fn := func() error {
_, err := inferenceModel.GenerateImage( _, err := inferenceModel.GenerateImage(
o.Context, appConfig.Context,
&proto.GenerateImageRequest{ &proto.GenerateImageRequest{
Height: int32(height), Height: int32(height),
Width: int32(width), Width: int32(width),
Mode: int32(mode), Mode: int32(mode),
Step: int32(step), Step: int32(step),
Seed: int32(seed), Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip), CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
PositivePrompt: positive_prompt, PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt, NegativePrompt: negative_prompt,
Dst: dst, Dst: dst,
Src: src, Src: src,
EnableParameters: c.Diffusers.EnableParameters, EnableParameters: backendConfig.Diffusers.EnableParameters,
}) })
return err return err
} }

View File

@ -8,8 +8,8 @@ import (
"sync" "sync"
"unicode/utf8" "unicode/utf8"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
@ -26,7 +26,7 @@ type TokenUsage struct {
Completion int Completion int
} }
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model modelFile := c.Model
grpcOpts := gRPCModelOpts(c) grpcOpts := gRPCModelOpts(c)
@ -140,7 +140,7 @@ func ModelInference(ctx context.Context, s string, images []string, loader *mode
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{} var mu sync.Mutex = sync.Mutex{}
func Finetune(config config.Config, input, prediction string) string { func Finetune(config config.BackendConfig, input, prediction string) string {
if config.Echo { if config.Echo {
prediction = input + prediction prediction = input + prediction
} }

View File

@ -4,19 +4,17 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/go-skynet/LocalAI/core/config"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
) )
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option {
if o.SingleBackend { if so.SingleBackend {
opts = append(opts, model.WithSingleActiveBackend()) opts = append(opts, model.WithSingleActiveBackend())
} }
if o.ParallelBackendRequests { if so.ParallelBackendRequests {
opts = append(opts, model.EnableParallelRequests) opts = append(opts, model.EnableParallelRequests)
} }
@ -28,14 +26,14 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) opts = append(opts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime))
} }
for k, v := range o.ExternalGRPCBackends { for k, v := range so.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v)) opts = append(opts, model.WithExternalBackend(k, v))
} }
return opts return opts
} }
func gRPCModelOpts(c config.Config) *pb.ModelOptions { func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512 b := 512
if c.Batch != 0 { if c.Batch != 0 {
b = c.Batch b = c.Batch
@ -84,7 +82,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
} }
} }
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOptions {
promptCachePath := "" promptCachePath := ""
if c.PromptCachePath != "" { if c.PromptCachePath != "" {
p := filepath.Join(modelPath, c.PromptCachePath) p := filepath.Join(modelPath, c.PromptCachePath)

View File

@ -4,25 +4,24 @@ import (
"context" "context"
"fmt" "fmt"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
) )
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) { func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {
opts := modelOpts(c, o, []model.Option{ opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend), model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model), model.WithModel(backendConfig.Model),
model.WithContext(o.Context), model.WithContext(appConfig.Context),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(backendConfig.Threads)),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(appConfig.AssetsDestination),
}) })
whisperModel, err := o.Loader.BackendLoader(opts...) whisperModel, err := ml.BackendLoader(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -34,6 +33,6 @@ func ModelTranscription(audio, language string, loader *model.ModelLoader, c con
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio, Dst: audio,
Language: language, Language: language,
Threads: uint32(c.Threads), Threads: uint32(backendConfig.Threads),
}) })
} }

View File

@ -6,8 +6,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
@ -29,22 +29,22 @@ func generateUniqueFileName(dir, baseName, ext string) string {
} }
} }
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option, c config.Config) (string, *proto.Result, error) { func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend bb := backend
if bb == "" { if bb == "" {
bb = model.PiperBackend bb = model.PiperBackend
} }
grpcOpts := gRPCModelOpts(c) grpcOpts := gRPCModelOpts(backendConfig)
opts := modelOpts(config.Config{}, o, []model.Option{ opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb), model.WithBackendString(bb),
model.WithModel(modelFile), model.WithModel(modelFile),
model.WithContext(o.Context), model.WithContext(appConfig.Context),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts), model.WithLoadGRPCLoadModelOpts(grpcOpts),
}) })
piperModel, err := o.Loader.BackendLoader(opts...) piperModel, err := loader.BackendLoader(opts...)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@ -53,19 +53,19 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
return "", nil, fmt.Errorf("could not load piper model") return "", nil, fmt.Errorf("could not load piper model")
} }
if err := os.MkdirAll(o.AudioDir, 0755); err != nil { if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err) return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
} }
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") fileName := generateUniqueFileName(appConfig.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName) filePath := filepath.Join(appConfig.AudioDir, fileName)
// If the model file is not empty, we pass it joined with the model path // If the model file is not empty, we pass it joined with the model path
modelPath := "" modelPath := ""
if modelFile != "" { if modelFile != "" {
if bb != model.TransformersMusicGen { if bb != model.TransformersMusicGen {
modelPath = filepath.Join(o.Loader.ModelPath, modelFile) modelPath = filepath.Join(loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { if err := utils.VerifyPath(modelPath, appConfig.ModelPath); err != nil {
return "", nil, err return "", nil, err
} }
} else { } else {

View File

@ -1,4 +1,4 @@
package options package config
import ( import (
"context" "context"
@ -6,16 +6,14 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Option struct { type ApplicationConfig struct {
Context context.Context Context context.Context
ConfigFile string ConfigFile string
Loader *model.ModelLoader ModelPath string
UploadLimitMB, Threads, ContextSize int UploadLimitMB, Threads, ContextSize int
F16 bool F16 bool
Debug, DisableMessage bool Debug, DisableMessage bool
@ -27,7 +25,6 @@ type Option struct {
PreloadModelsFromPath string PreloadModelsFromPath string
CORSAllowOrigins string CORSAllowOrigins string
ApiKeys []string ApiKeys []string
Metrics *metrics.Metrics
ModelLibraryURL string ModelLibraryURL string
@ -52,10 +49,10 @@ type Option struct {
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
} }
type AppOption func(*Option) type AppOption func(*ApplicationConfig)
func NewOptions(o ...AppOption) *Option { func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
opt := &Option{ opt := &ApplicationConfig{
Context: context.Background(), Context: context.Background(),
UploadLimitMB: 15, UploadLimitMB: 15,
Threads: 1, Threads: 1,
@ -70,63 +67,69 @@ func NewOptions(o ...AppOption) *Option {
} }
func WithModelsURL(urls ...string) AppOption { func WithModelsURL(urls ...string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.ModelsURL = urls o.ModelsURL = urls
} }
} }
func WithModelPath(path string) AppOption {
return func(o *ApplicationConfig) {
o.ModelPath = path
}
}
func WithCors(b bool) AppOption { func WithCors(b bool) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.CORS = b o.CORS = b
} }
} }
func WithModelLibraryURL(url string) AppOption { func WithModelLibraryURL(url string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.ModelLibraryURL = url o.ModelLibraryURL = url
} }
} }
var EnableWatchDog = func(o *Option) { var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true o.WatchDog = true
} }
var EnableWatchDogIdleCheck = func(o *Option) { var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
o.WatchDog = true o.WatchDog = true
o.WatchDogIdle = true o.WatchDogIdle = true
} }
var EnableWatchDogBusyCheck = func(o *Option) { var EnableWatchDogBusyCheck = func(o *ApplicationConfig) {
o.WatchDog = true o.WatchDog = true
o.WatchDogBusy = true o.WatchDogBusy = true
} }
func SetWatchDogBusyTimeout(t time.Duration) AppOption { func SetWatchDogBusyTimeout(t time.Duration) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.WatchDogBusyTimeout = t o.WatchDogBusyTimeout = t
} }
} }
func SetWatchDogIdleTimeout(t time.Duration) AppOption { func SetWatchDogIdleTimeout(t time.Duration) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.WatchDogIdleTimeout = t o.WatchDogIdleTimeout = t
} }
} }
var EnableSingleBackend = func(o *Option) { var EnableSingleBackend = func(o *ApplicationConfig) {
o.SingleBackend = true o.SingleBackend = true
} }
var EnableParallelBackendRequests = func(o *Option) { var EnableParallelBackendRequests = func(o *ApplicationConfig) {
o.ParallelBackendRequests = true o.ParallelBackendRequests = true
} }
var EnableGalleriesAutoload = func(o *Option) { var EnableGalleriesAutoload = func(o *ApplicationConfig) {
o.AutoloadGalleries = true o.AutoloadGalleries = true
} }
func WithExternalBackend(name string, uri string) AppOption { func WithExternalBackend(name string, uri string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
if o.ExternalGRPCBackends == nil { if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string) o.ExternalGRPCBackends = make(map[string]string)
} }
@ -135,25 +138,25 @@ func WithExternalBackend(name string, uri string) AppOption {
} }
func WithCorsAllowOrigins(b string) AppOption { func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.CORSAllowOrigins = b o.CORSAllowOrigins = b
} }
} }
func WithBackendAssetsOutput(out string) AppOption { func WithBackendAssetsOutput(out string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.AssetsDestination = out o.AssetsDestination = out
} }
} }
func WithBackendAssets(f embed.FS) AppOption { func WithBackendAssets(f embed.FS) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.BackendAssets = f o.BackendAssets = f
} }
} }
func WithStringGalleries(galls string) AppOption { func WithStringGalleries(galls string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
if galls == "" { if galls == "" {
log.Debug().Msgf("no galleries to load") log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{} o.Galleries = []gallery.Gallery{}
@ -168,102 +171,96 @@ func WithStringGalleries(galls string) AppOption {
} }
func WithGalleries(galleries []gallery.Gallery) AppOption { func WithGalleries(galleries []gallery.Gallery) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.Galleries = append(o.Galleries, galleries...) o.Galleries = append(o.Galleries, galleries...)
} }
} }
func WithContext(ctx context.Context) AppOption { func WithContext(ctx context.Context) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.Context = ctx o.Context = ctx
} }
} }
func WithYAMLConfigPreload(configFile string) AppOption { func WithYAMLConfigPreload(configFile string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.PreloadModelsFromPath = configFile o.PreloadModelsFromPath = configFile
} }
} }
func WithJSONStringPreload(configFile string) AppOption { func WithJSONStringPreload(configFile string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.PreloadJSONModels = configFile o.PreloadJSONModels = configFile
} }
} }
func WithConfigFile(configFile string) AppOption { func WithConfigFile(configFile string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.ConfigFile = configFile o.ConfigFile = configFile
} }
} }
func WithModelLoader(loader *model.ModelLoader) AppOption {
return func(o *Option) {
o.Loader = loader
}
}
func WithUploadLimitMB(limit int) AppOption { func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.UploadLimitMB = limit o.UploadLimitMB = limit
} }
} }
func WithThreads(threads int) AppOption { func WithThreads(threads int) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.Threads = threads o.Threads = threads
} }
} }
func WithContextSize(ctxSize int) AppOption { func WithContextSize(ctxSize int) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.ContextSize = ctxSize o.ContextSize = ctxSize
} }
} }
func WithF16(f16 bool) AppOption { func WithF16(f16 bool) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.F16 = f16 o.F16 = f16
} }
} }
func WithDebug(debug bool) AppOption { func WithDebug(debug bool) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.Debug = debug o.Debug = debug
} }
} }
func WithDisableMessage(disableMessage bool) AppOption { func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.DisableMessage = disableMessage o.DisableMessage = disableMessage
} }
} }
func WithAudioDir(audioDir string) AppOption { func WithAudioDir(audioDir string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.AudioDir = audioDir o.AudioDir = audioDir
} }
} }
func WithImageDir(imageDir string) AppOption { func WithImageDir(imageDir string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.ImageDir = imageDir o.ImageDir = imageDir
} }
} }
func WithUploadDir(uploadDir string) AppOption { func WithUploadDir(uploadDir string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.UploadDir = uploadDir o.UploadDir = uploadDir
} }
} }
func WithApiKeys(apiKeys []string) AppOption { func WithApiKeys(apiKeys []string) AppOption {
return func(o *Option) { return func(o *ApplicationConfig) {
o.ApiKeys = apiKeys o.ApiKeys = apiKeys
} }
} }
func WithMetrics(meter *metrics.Metrics) AppOption { // func WithMetrics(meter *metrics.Metrics) AppOption {
return func(o *Option) { // return func(o *StartupOptions) {
o.Metrics = meter // o.Metrics = meter
} // }
} // }

View File

@ -9,15 +9,16 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type Config struct { type BackendConfig struct {
PredictionOptions `yaml:"parameters"` schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"` Name string `yaml:"name"`
F16 bool `yaml:"f16"` F16 bool `yaml:"f16"`
Threads int `yaml:"threads"` Threads int `yaml:"threads"`
@ -159,37 +160,55 @@ type TemplateConfig struct {
Functions string `yaml:"function"` Functions string `yaml:"function"`
} }
type ConfigLoader struct { func (c *BackendConfig) SetFunctionCallString(s string) {
configs map[string]Config
sync.Mutex
}
func (c *Config) SetFunctionCallString(s string) {
c.functionCallString = s c.functionCallString = s
} }
func (c *Config) SetFunctionCallNameString(s string) { func (c *BackendConfig) SetFunctionCallNameString(s string) {
c.functionCallNameString = s c.functionCallNameString = s
} }
func (c *Config) ShouldUseFunctions() bool { func (c *BackendConfig) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
} }
func (c *Config) ShouldCallSpecificFunction() bool { func (c *BackendConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0 return len(c.functionCallNameString) > 0
} }
func (c *Config) FunctionToCall() string { func (c *BackendConfig) FunctionToCall() string {
return c.functionCallNameString return c.functionCallNameString
} }
func defaultPredictOptions(modelFile string) schema.PredictionOptions {
return schema.PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
}
}
func DefaultConfig(modelFile string) *BackendConfig {
return &BackendConfig{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
////// Config Loader ////////
type BackendConfigLoader struct {
configs map[string]BackendConfig
sync.Mutex
}
// Load a config file for a model // Load a config file for a model
func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ctx int, f16 bool) (*Config, error) { func LoadBackendConfigFileByName(modelName, modelPath string, cl *BackendConfigLoader, debug bool, threads, ctx int, f16 bool) (*BackendConfig, error) {
// Load a config file if present after the model name // Load a config file if present after the model name
modelConfig := filepath.Join(modelPath, modelName+".yaml") modelConfig := filepath.Join(modelPath, modelName+".yaml")
var cfg *Config var cfg *BackendConfig
defaults := func() { defaults := func() {
cfg = DefaultConfig(modelName) cfg = DefaultConfig(modelName)
@ -199,13 +218,13 @@ func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ct
cfg.Debug = debug cfg.Debug = debug
} }
cfgExisting, exists := cm.GetConfig(modelName) cfgExisting, exists := cl.GetBackendConfig(modelName)
if !exists { if !exists {
if _, err := os.Stat(modelConfig); err == nil { if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil { if err := cl.LoadBackendConfig(modelConfig); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
} }
cfgExisting, exists = cm.GetConfig(modelName) cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists { if exists {
cfg = &cfgExisting cfg = &cfgExisting
} else { } else {
@ -238,29 +257,13 @@ func Load(modelName, modelPath string, cm *ConfigLoader, debug bool, threads, ct
return cfg, nil return cfg, nil
} }
func defaultPredictOptions(modelFile string) PredictionOptions { func NewBackendConfigLoader() *BackendConfigLoader {
return PredictionOptions{ return &BackendConfigLoader{
TopP: 0.7, configs: make(map[string]BackendConfig),
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
} }
} }
func ReadBackendConfigFile(file string) ([]*BackendConfig, error) {
func DefaultConfig(modelFile string) *Config { c := &[]*BackendConfig{}
return &Config{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file) f, err := os.ReadFile(file)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err) return nil, fmt.Errorf("cannot read config file: %w", err)
@ -272,8 +275,8 @@ func ReadConfigFile(file string) ([]*Config, error) {
return *c, nil return *c, nil
} }
func ReadConfig(file string) (*Config, error) { func ReadBackendConfig(file string) (*BackendConfig, error) {
c := &Config{} c := &BackendConfig{}
f, err := os.ReadFile(file) f, err := os.ReadFile(file)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err) return nil, fmt.Errorf("cannot read config file: %w", err)
@ -285,10 +288,10 @@ func ReadConfig(file string) (*Config, error) {
return c, nil return c, nil
} }
func (cm *ConfigLoader) LoadConfigFile(file string) error { func (cm *BackendConfigLoader) LoadBackendConfigFile(file string) error {
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
c, err := ReadConfigFile(file) c, err := ReadBackendConfigFile(file)
if err != nil { if err != nil {
return fmt.Errorf("cannot load config file: %w", err) return fmt.Errorf("cannot load config file: %w", err)
} }
@ -299,49 +302,49 @@ func (cm *ConfigLoader) LoadConfigFile(file string) error {
return nil return nil
} }
func (cm *ConfigLoader) LoadConfig(file string) error { func (cl *BackendConfigLoader) LoadBackendConfig(file string) error {
cm.Lock() cl.Lock()
defer cm.Unlock() defer cl.Unlock()
c, err := ReadConfig(file) c, err := ReadBackendConfig(file)
if err != nil { if err != nil {
return fmt.Errorf("cannot read config file: %w", err) return fmt.Errorf("cannot read config file: %w", err)
} }
cm.configs[c.Name] = *c cl.configs[c.Name] = *c
return nil return nil
} }
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cm.Lock() cl.Lock()
defer cm.Unlock() defer cl.Unlock()
v, exists := cm.configs[m] v, exists := cl.configs[m]
return v, exists return v, exists
} }
func (cm *ConfigLoader) GetAllConfigs() []Config { func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cm.Lock() cl.Lock()
defer cm.Unlock() defer cl.Unlock()
var res []Config var res []BackendConfig
for _, v := range cm.configs { for _, v := range cl.configs {
res = append(res, v) res = append(res, v)
} }
return res return res
} }
func (cm *ConfigLoader) ListConfigs() []string { func (cl *BackendConfigLoader) ListBackendConfigs() []string {
cm.Lock() cl.Lock()
defer cm.Unlock() defer cl.Unlock()
var res []string var res []string
for k := range cm.configs { for k := range cl.configs {
res = append(res, k) res = append(res, k)
} }
return res return res
} }
// Preload prepare models if they are not local but url or huggingface repositories // Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error { func (cl *BackendConfigLoader) Preload(modelPath string) error {
cm.Lock() cl.Lock()
defer cm.Unlock() defer cl.Unlock()
status := func(fileName, current, total string, percent float64) { status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent) utils.DisplayDownloadFunction(fileName, current, total, percent)
@ -349,7 +352,7 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
log.Info().Msgf("Preloading models from %s", modelPath) log.Info().Msgf("Preloading models from %s", modelPath)
for i, config := range cm.configs { for i, config := range cl.configs {
// Download files and verify their SHA // Download files and verify their SHA
for _, file := range config.DownloadFiles { for _, file := range config.DownloadFiles {
@ -381,25 +384,25 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
} }
} }
cc := cm.configs[i] cc := cl.configs[i]
c := &cc c := &cc
c.PredictionOptions.Model = md5Name c.PredictionOptions.Model = md5Name
cm.configs[i] = *c cl.configs[i] = *c
} }
if cm.configs[i].Name != "" { if cl.configs[i].Name != "" {
log.Info().Msgf("Model name: %s", cm.configs[i].Name) log.Info().Msgf("Model name: %s", cl.configs[i].Name)
} }
if cm.configs[i].Description != "" { if cl.configs[i].Description != "" {
log.Info().Msgf("Model description: %s", cm.configs[i].Description) log.Info().Msgf("Model description: %s", cl.configs[i].Description)
} }
if cm.configs[i].Usage != "" { if cl.configs[i].Usage != "" {
log.Info().Msgf("Model usage: \n%s", cm.configs[i].Usage) log.Info().Msgf("Model usage: \n%s", cl.configs[i].Usage)
} }
} }
return nil return nil
} }
func (cm *ConfigLoader) LoadConfigs(path string) error { func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string) error {
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
entries, err := os.ReadDir(path) entries, err := os.ReadDir(path)
@ -419,7 +422,7 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
continue continue
} }
c, err := ReadConfig(filepath.Join(path, file.Name())) c, err := ReadBackendConfig(filepath.Join(path, file.Name()))
if err == nil { if err == nil {
cm.configs[c.Name] = *c cm.configs[c.Name] = *c
} }

View File

@ -4,8 +4,7 @@ import (
"os" "os"
. "github.com/go-skynet/LocalAI/core/config" . "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
Context("Test Read configuration functions", func() { Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE") configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() { It("Test ReadConfigFile", func() {
config, err := ReadConfigFile(configFile) config, err := ReadBackendConfigFile(configFile)
Expect(err).To(BeNil()) Expect(err).To(BeNil())
Expect(config).ToNot(BeNil()) Expect(config).ToNot(BeNil())
// two configs in config.yaml // two configs in config.yaml
@ -28,29 +27,26 @@ var _ = Describe("Test cases for config related functions", func() {
}) })
It("Test LoadConfigs", func() { It("Test LoadConfigs", func() {
cm := NewConfigLoader() cm := NewBackendConfigLoader()
opts := options.NewOptions() opts := NewApplicationConfig()
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) err := cm.LoadBackendConfigsFromPath(opts.ModelPath)
options.WithModelLoader(modelLoader)(opts)
err := cm.LoadConfigs(opts.Loader.ModelPath)
Expect(err).To(BeNil()) Expect(err).To(BeNil())
Expect(cm.ListConfigs()).ToNot(BeNil()) Expect(cm.ListBackendConfigs()).ToNot(BeNil())
// config should includes gpt4all models's api.config // config should includes gpt4all models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all")) Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all"))
// config should includes gpt2 models's api.config // config should includes gpt2 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2")) Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all-2"))
// config should includes text-embedding-ada-002 models's api.config // config should includes text-embedding-ada-002 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002")) Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
// config should includes rwkv_test models's api.config // config should includes rwkv_test models's api.config
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test")) Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
// config should includes whisper-1 models's api.config // config should includes whisper-1 models's api.config
Expect(cm.ListConfigs()).To(ContainElements("whisper-1")) Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
}) })
}) })
}) })

View File

@ -3,122 +3,29 @@ package http
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"os" "os"
"strings" "strings"
"github.com/go-skynet/LocalAI/api/localai" "github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/api/openai" "github.com/go-skynet/LocalAI/core/http/endpoints/openai"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/startup"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover" "github.com/gofiber/fiber/v2/middleware/recover"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
) )
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) {
options := options.NewOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)
cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
return options, cl, nil
}
func App(opts ...options.AppOption) (*fiber.App, error) {
options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
// Return errors as JSON responses // Return errors as JSON responses
app := fiber.New(fiber.Config{ app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage, DisableStartupMessage: appConfig.DisableMessage,
// Override default error handler // Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error { ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500 // Status code defaults to 500
@ -139,7 +46,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
}, },
}) })
if options.Debug { if appConfig.Debug {
app.Use(logger.New(logger.Config{ app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
})) }))
@ -147,17 +54,25 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
// Default middleware config // Default middleware config
if !options.Debug { if !appConfig.Debug {
app.Use(recover.New()) app.Use(recover.New())
} }
if options.Metrics != nil { metricsService, err := services.NewLocalAIMetricsService()
app.Use(metrics.APIMiddleware(options.Metrics)) if err != nil {
return nil, err
}
if metricsService != nil {
app.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
app.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
} }
// Auth middleware checking if API key is valid. If no API key is set, no auth is required. // Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error { auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) == 0 { if len(appConfig.ApiKeys) == 0 {
return c.Next() return c.Next()
} }
@ -172,10 +87,10 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
} }
// Add file keys to options.ApiKeys // Add file keys to options.ApiKeys
options.ApiKeys = append(options.ApiKeys, fileKeys...) appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
} }
if len(options.ApiKeys) == 0 { if len(appConfig.ApiKeys) == 0 {
return c.Next() return c.Next()
} }
@ -189,7 +104,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
} }
apiKey := authHeaderParts[1] apiKey := authHeaderParts[1]
for _, key := range options.ApiKeys { for _, key := range appConfig.ApiKeys {
if apiKey == key { if apiKey == key {
return c.Next() return c.Next()
} }
@ -199,20 +114,20 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
} }
if options.CORS { if appConfig.CORS {
var c func(ctx *fiber.Ctx) error var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" { if appConfig.CORSAllowOrigins == "" {
c = cors.New() c = cors.New()
} else { } else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins})
} }
app.Use(c) app.Use(c)
} }
// LocalAI API endpoints // LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath) galleryService := services.NewGalleryService(appConfig.ModelPath)
galleryService.Start(options.Context, cl) galleryService.Start(appConfig.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error { app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct { return c.JSON(struct {
@ -220,69 +135,63 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
}{Version: internal.PrintableVersion()}) }{Version: internal.PrintableVersion()})
}) })
// Make sure directories exists
os.MkdirAll(options.ImageDir, 0755)
os.MkdirAll(options.AudioDir, 0755)
os.MkdirAll(options.UploadDir, 0755)
os.MkdirAll(options.Loader.ModelPath, 0755)
// Load upload json // Load upload json
openai.LoadUploadConfig(options.UploadDir) openai.LoadUploadConfig(appConfig.UploadDir)
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())
// openAI compatible API endpoint // openAI compatible API endpoint
// chat // chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
// edit // edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, options)) app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
// files // files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options)) app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options)) app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options)) app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, options)) app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
// completion // completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
// embeddings // embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
// audio // audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))
// images // images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))
if options.ImageDir != "" { if appConfig.ImageDir != "" {
app.Static("/generated-images", options.ImageDir) app.Static("/generated-images", appConfig.ImageDir)
} }
if options.AudioDir != "" { if appConfig.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir) app.Static("/generated-audio", appConfig.AudioDir)
} }
ok := func(c *fiber.Ctx) error { ok := func(c *fiber.Ctx) error {
@ -294,15 +203,15 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
app.Get("/readyz", ok) app.Get("/readyz", ok)
// Experimental Backend Statistics Module // Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
// models // models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/metrics", metrics.MetricsHandler()) app.Get("/metrics", localai.LocalAIMetricsEndpoint())
return app, nil return app, nil
} }

View File

@ -13,9 +13,10 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/go-skynet/LocalAI/core/config"
. "github.com/go-skynet/LocalAI/core/http" . "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
@ -127,25 +128,33 @@ var backendAssets embed.FS
var _ = Describe("API test", func() { var _ = Describe("API test", func() {
var app *fiber.App var app *fiber.App
var modelLoader *model.ModelLoader
var client *openai.Client var client *openai.Client
var client2 *openaigo.Client var client2 *openaigo.Client
var c context.Context var c context.Context
var cancel context.CancelFunc var cancel context.CancelFunc
var tmpdir string var tmpdir string
var modelDir string
var bcl *config.BackendConfigLoader
var ml *model.ModelLoader
var applicationConfig *config.ApplicationConfig
commonOpts := []options.AppOption{ commonOpts := []config.AppOption{
options.WithDebug(true), config.WithDebug(true),
options.WithDisableMessage(true), config.WithDisableMessage(true),
} }
Context("API with ephemeral models", func() { Context("API with ephemeral models", func() {
BeforeEach(func() {
BeforeEach(func(sc SpecContext) {
var err error var err error
tmpdir, err = os.MkdirTemp("", "") tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir) modelDir = filepath.Join(tmpdir, "models")
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
g := []gallery.GalleryModel{ g := []gallery.GalleryModel{
@ -172,16 +181,18 @@ var _ = Describe("API test", func() {
}, },
} }
metricsService, err := metrics.SetupMetrics() bcl, ml, applicationConfig, err = startup.Startup(
append(commonOpts,
config.WithContext(c),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App( app, err = App(bcl, ml, applicationConfig)
append(commonOpts,
options.WithMetrics(metricsService),
options.WithContext(c),
options.WithGalleries(galleries),
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -198,15 +209,21 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred()) }, "2m").ShouldNot(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func(sc SpecContext) {
cancel() cancel()
app.Shutdown() if app != nil {
os.RemoveAll(tmpdir) err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadDir(tmpdir)
Expect(err).To(HaveOccurred())
}) })
Context("Applying models", func() { Context("Applying models", func() {
It("applies models from a gallery", func() {
It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available") models := getModels("http://127.0.0.1:9090/models/available")
Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
@ -228,10 +245,10 @@ var _ = Describe("API test", func() {
}, "360s", "10s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error")) Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml")) dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml")) _, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]interface{}{}
@ -253,6 +270,7 @@ var _ = Describe("API test", func() {
} }
}) })
It("overrides models", func() { It("overrides models", func() {
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Name: "bert", Name: "bert",
@ -270,7 +288,7 @@ var _ = Describe("API test", func() {
return response["processed"].(bool) return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]interface{}{}
@ -294,7 +312,7 @@ var _ = Describe("API test", func() {
return response["processed"].(bool) return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
content := map[string]interface{}{} content := map[string]interface{}{}
@ -483,8 +501,11 @@ var _ = Describe("API test", func() {
var err error var err error
tmpdir, err = os.MkdirTemp("", "") tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
modelDir = filepath.Join(tmpdir, "models")
backendAssetsDir := filepath.Join(tmpdir, "backend-assets")
err = os.Mkdir(backendAssetsDir, 0755)
Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{ galleries := []gallery.Gallery{
@ -494,21 +515,20 @@ var _ = Describe("API test", func() {
}, },
} }
metricsService, err := metrics.SetupMetrics() bcl, ml, applicationConfig, err = startup.Startup(
Expect(err).ToNot(HaveOccurred())
app, err = App(
append(commonOpts, append(commonOpts,
options.WithContext(c), config.WithContext(c),
options.WithMetrics(metricsService), config.WithAudioDir(tmpdir),
options.WithAudioDir(tmpdir), config.WithImageDir(tmpdir),
options.WithImageDir(tmpdir), config.WithGalleries(galleries),
options.WithGalleries(galleries), config.WithModelPath(modelDir),
options.WithModelLoader(modelLoader), config.WithBackendAssets(backendAssets),
options.WithBackendAssets(backendAssets), config.WithBackendAssetsOutput(tmpdir))...,
options.WithBackendAssetsOutput(tmpdir))...,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -527,8 +547,14 @@ var _ = Describe("API test", func() {
AfterEach(func() { AfterEach(func() {
cancel() cancel()
app.Shutdown() if app != nil {
os.RemoveAll(tmpdir) err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
Expect(err).ToNot(HaveOccurred())
_, err = os.ReadDir(tmpdir)
Expect(err).To(HaveOccurred())
}) })
It("installs and is capable to run tts", Label("tts"), func() { It("installs and is capable to run tts", Label("tts"), func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
@ -599,20 +625,20 @@ var _ = Describe("API test", func() {
Context("API query", func() { Context("API query", func() {
BeforeEach(func() { BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) modelPath := os.Getenv("MODELS_PATH")
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics() var err error
Expect(err).ToNot(HaveOccurred())
app, err = App( bcl, ml, applicationConfig, err = startup.Startup(
append(commonOpts, append(commonOpts,
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), config.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithContext(c), config.WithContext(c),
options.WithModelLoader(modelLoader), config.WithModelPath(modelPath),
options.WithMetrics(metricsService),
)...) )...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -630,7 +656,10 @@ var _ = Describe("API test", func() {
}) })
AfterEach(func() { AfterEach(func() {
cancel() cancel()
app.Shutdown() if app != nil {
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
}) })
It("returns the models list", func() { It("returns the models list", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
@ -811,20 +840,20 @@ var _ = Describe("API test", func() {
Context("Config file", func() { Context("Config file", func() {
BeforeEach(func() { BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) modelPath := os.Getenv("MODELS_PATH")
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
metricsService, err := metrics.SetupMetrics() var err error
Expect(err).ToNot(HaveOccurred()) bcl, ml, applicationConfig, err = startup.Startup(
app, err = App(
append(commonOpts, append(commonOpts,
options.WithContext(c), config.WithContext(c),
options.WithMetrics(metricsService), config.WithModelPath(modelPath),
options.WithModelLoader(modelLoader), config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = App(bcl, ml, applicationConfig)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -840,7 +869,10 @@ var _ = Describe("API test", func() {
}) })
AfterEach(func() { AfterEach(func() {
cancel() cancel()
app.Shutdown() if app != nil {
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
}) })
It("can generate chat completions from config file (list1)", func() { It("can generate chat completions from config file (list1)", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})

View File

@ -0,0 +1,36 @@
package localai
import (
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/gofiber/fiber/v2"
)
func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}
func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}

View File

@ -0,0 +1,146 @@
package localai
import (
"encoding/json"
"fmt"
"slices"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type ModelGalleryEndpointService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *services.GalleryService
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.GetAllStatus())
}
}
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.C <- gallery.GalleryOp{
Req: input.GalleryModel,
Id: uuid.String(),
GalleryName: input.ID,
Galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
log.Debug().Msgf("Models found from galleries: %+v", models)
for _, m := range models {
log.Debug().Msgf("Model found from galleries: %+v", m)
}
dat, err := json.Marshal(models)
if err != nil {
return err
}
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

View File

@ -0,0 +1,43 @@
package localai
import (
"time"
"github.com/go-skynet/LocalAI/core/services"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func LocalAIMetricsEndpoint() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metricsService *services.LocalAIMetricsService
}
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
cfg := apiMiddlewareConfig{
metricsService: metrics,
Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics"
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err
}
}

View File

@ -1,37 +1,32 @@
package localai package localai
import ( import (
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/rs/zerolog/log" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/core/schema"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
) )
type TTSRequest struct { func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
input := new(TTSRequest) input := new(schema.TTSRequest)
// Get input data from the request body // Get input data from the request body
if err := c.BodyParser(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, false) modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
if err != nil { if err != nil {
modelFile = input.Model modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model) log.Warn().Msgf("Model not found in context: %s", input.Model)
} }
cfg, err := config.Load(modelFile, o.Loader.ModelPath, cm, false, 0, 0, false) cfg, err := config.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, cl, false, 0, 0, false)
if err != nil { if err != nil {
modelFile = input.Model modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model) log.Warn().Msgf("Model not found in context: %s", input.Model)
@ -44,7 +39,7 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
cfg.Backend = input.Backend cfg.Backend = input.Backend
} }
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, o.Loader, o, *cfg) filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, appConfig, *cfg)
if err != nil { if err != nil {
return err return err
} }

View File

@ -9,8 +9,7 @@ import (
"time" "time"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
@ -21,12 +20,12 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
emptyMessage := "" emptyMessage := ""
id := uuid.New().String() id := uuid.New().String()
created := int(time.Now().Unix()) created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
initialMessage := schema.OpenAIResponse{ initialMessage := schema.OpenAIResponse{
ID: id, ID: id,
Created: created, Created: created,
@ -36,7 +35,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
responses <- initialMessage responses <- initialMessage
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{ resp := schema.OpenAIResponse{
ID: id, ID: id,
Created: created, Created: created,
@ -55,9 +54,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}) })
close(responses) close(responses)
} }
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
result := "" result := ""
_, tokenUsage, _ := ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { _, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s result += s
// TODO: Change generated BNF grammar to be compliant with the schema so we can // TODO: Change generated BNF grammar to be compliant with the schema so we can
// stream the result token by token here. // stream the result token by token here.
@ -78,7 +77,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
responses <- initialMessage responses <- initialMessage
result, err := handleQuestion(config, req, o, results[0].arguments, prompt) result, err := handleQuestion(config, req, ml, startupOptions, results[0].arguments, prompt)
if err != nil { if err != nil {
log.Error().Msgf("error handling question: %s", err.Error()) log.Error().Msgf("error handling question: %s", err.Error())
return return
@ -154,12 +153,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
processFunctions := false processFunctions := false
funcs := grammar.Functions{} funcs := grammar.Functions{}
modelFile, input, err := readRequest(c, o, true) modelFile, input, err := readRequest(c, ml, startupOptions, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -252,7 +251,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
FunctionName: i.Name, FunctionName: i.Name,
MessageIndex: messageIndex, MessageIndex: messageIndex,
} }
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
if err != nil { if err != nil {
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err) log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
} else { } else {
@ -320,7 +319,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
templateFile := "" templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model templateFile = config.Model
} }
@ -333,7 +332,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
if templateFile != "" { if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt, SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt, SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput, Input: predInput,
@ -357,9 +356,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
responses := make(chan schema.OpenAIResponse) responses := make(chan schema.OpenAIResponse)
if !processFunctions { if !processFunctions {
go process(predInput, input, config, o.Loader, responses) go process(predInput, input, config, ml, responses)
} else { } else {
go processTools(noActionName, predInput, input, config, o.Loader, responses) go processTools(noActionName, predInput, input, config, ml, responses)
} }
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@ -413,7 +412,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
// no streaming mode // no streaming mode
default: default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !processFunctions { if !processFunctions {
// no function is called, just reply and use stop as finish reason // no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
@ -425,7 +424,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
switch { switch {
case noActionsToRun: case noActionsToRun:
result, err := handleQuestion(config, input, o, results[0].arguments, predInput) result, err := handleQuestion(config, input, ml, startupOptions, results[0].arguments, predInput)
if err != nil { if err != nil {
log.Error().Msgf("error handling question: %s", err.Error()) log.Error().Msgf("error handling question: %s", err.Error())
return return
@ -506,7 +505,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
} }
func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *options.Option, args, prompt string) (string, error) { func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, args, prompt string) (string, error) {
log.Debug().Msgf("nothing to do, computing a reply") log.Debug().Msgf("nothing to do, computing a reply")
// If there is a message that the LLM already sends as part of the JSON reply, use it // If there is a message that the LLM already sends as part of the JSON reply, use it
@ -535,7 +534,7 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio
images = append(images, m.StringImages...) images = append(images, m.StringImages...)
} }
predFunc, err := backend.ModelInference(input.Context, prompt, images, o.Loader, *config, o, nil) predFunc, err := backend.ModelInference(input.Context, prompt, images, ml, *config, o, nil)
if err != nil { if err != nil {
log.Error().Msgf("inference error: %s", err.Error()) log.Error().Msgf("inference error: %s", err.Error())
return "", err return "", err

View File

@ -9,8 +9,8 @@ import (
"time" "time"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
@ -21,12 +21,12 @@ import (
) )
// https://platform.openai.com/docs/api-reference/completions // https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
id := uuid.New().String() id := uuid.New().String()
created := int(time.Now().Unix()) created := int(time.Now().Unix())
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
resp := schema.OpenAIResponse{ resp := schema.OpenAIResponse{
ID: id, ID: id,
Created: created, Created: created,
@ -53,14 +53,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
} }
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, o, true) modelFile, input, err := readRequest(c, ml, appConfig, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
log.Debug().Msgf("`input`: %+v", input) log.Debug().Msgf("`input`: %+v", input)
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -84,7 +84,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
templateFile := "" templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model templateFile = config.Model
} }
@ -100,7 +100,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
predInput := config.PromptStrings[0] predInput := config.PromptStrings[0]
if templateFile != "" { if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput, Input: predInput,
}) })
if err == nil { if err == nil {
@ -111,7 +111,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
responses := make(chan schema.OpenAIResponse) responses := make(chan schema.OpenAIResponse)
go process(predInput, input, config, o.Loader, responses) go process(predInput, input, config, ml, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@ -153,7 +153,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
for k, i := range config.PromptStrings { for k, i := range config.PromptStrings {
if templateFile != "" { if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt, SystemPrompt: config.SystemPrompt,
Input: i, Input: i,
}) })
@ -164,7 +164,7 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
} }
r, tokenUsage, err := ComputeChoices( r, tokenUsage, err := ComputeChoices(
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil) }, nil)
if err != nil { if err != nil {

View File

@ -6,8 +6,8 @@ import (
"time" "time"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@ -16,14 +16,14 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
modelFile, input, err := readRequest(c, o, true) modelFile, input, err := readRequest(c, ml, appConfig, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := mergeRequestWithConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -33,7 +33,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
templateFile := "" templateFile := ""
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model templateFile = config.Model
} }
@ -46,7 +46,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
for _, i := range config.InputStrings { for _, i := range config.InputStrings {
if templateFile != "" { if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i, Input: i,
Instruction: input.Instruction, Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt, SystemPrompt: config.SystemPrompt,
@ -57,7 +57,7 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
} }
} }
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s}) *c = append(*c, schema.Choice{Text: s})
}, nil) }, nil)
if err != nil { if err != nil {

View File

@ -6,24 +6,25 @@ import (
"time" "time"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/go-skynet/LocalAI/core/options"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// https://platform.openai.com/docs/api-reference/embeddings // https://platform.openai.com/docs/api-reference/embeddings
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
model, input, err := readRequest(c, o, true) model, input, err := readRequest(c, ml, appConfig, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := mergeRequestWithConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -33,7 +34,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
for i, s := range config.InputToken { for i, s := range config.InputToken {
// get the model function to call for the result // get the model function to call for the result
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig)
if err != nil { if err != nil {
return err return err
} }
@ -47,7 +48,7 @@ func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
for i, s := range config.InputStrings { for i, s := range config.InputStrings {
// get the model function to call for the result // get the model function to call for the result
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig)
if err != nil { if err != nil {
return err return err
} }

View File

@ -8,8 +8,8 @@ import (
"path/filepath" "path/filepath"
"time" "time"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -62,7 +62,7 @@ func LoadUploadConfig(uploadPath string) {
} }
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create // UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
@ -70,8 +70,8 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
} }
// Check the file size // Check the file size
if file.Size > int64(o.UploadLimitMB*1024*1024) { if file.Size > int64(appConfig.UploadLimitMB*1024*1024) {
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB)) return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB))
} }
purpose := c.FormValue("purpose", "") //TODO put in purpose dirs purpose := c.FormValue("purpose", "") //TODO put in purpose dirs
@ -82,7 +82,7 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
// Sanitize the filename to prevent directory traversal // Sanitize the filename to prevent directory traversal
filename := utils.SanitizeFileName(file.Filename) filename := utils.SanitizeFileName(file.Filename)
savePath := filepath.Join(o.UploadDir, filename) savePath := filepath.Join(appConfig.UploadDir, filename)
// Check if file already exists // Check if file already exists
if _, err := os.Stat(savePath); !os.IsNotExist(err) { if _, err := os.Stat(savePath); !os.IsNotExist(err) {
@ -104,13 +104,13 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
} }
uploadedFiles = append(uploadedFiles, f) uploadedFiles = append(uploadedFiles, f)
saveUploadConfig(o.UploadDir) saveUploadConfig(appConfig.UploadDir)
return c.Status(fiber.StatusOK).JSON(f) return c.Status(fiber.StatusOK).JSON(f)
} }
} }
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list // ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
type ListFiles struct { type ListFiles struct {
Data []File Data []File
Object string Object string
@ -150,7 +150,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) {
} }
// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve // GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve
func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c) file, err := getFileFromRequest(c)
if err != nil { if err != nil {
@ -162,7 +162,7 @@ func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.
} }
// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete // DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete
func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
type DeleteStatus struct { type DeleteStatus struct {
Id string Id string
Object string Object string
@ -175,7 +175,7 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }
err = os.Remove(filepath.Join(o.UploadDir, file.Filename)) err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil { if err != nil {
// If the file doesn't exist then we should just continue to remove it // If the file doesn't exist then we should just continue to remove it
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
@ -191,7 +191,7 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
} }
} }
saveUploadConfig(o.UploadDir) saveUploadConfig(appConfig.UploadDir)
return c.JSON(DeleteStatus{ return c.JSON(DeleteStatus{
Id: file.ID, Id: file.ID,
Object: "file", Object: "file",
@ -201,14 +201,14 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib
} }
// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents // GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents
func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
file, err := getFileFromRequest(c) file, err := getFileFromRequest(c)
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }
fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename)) fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename))
if err != nil { if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
} }

View File

@ -11,8 +11,8 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
utils2 "github.com/go-skynet/LocalAI/pkg/utils" utils2 "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -25,11 +25,11 @@ type ListFiles struct {
Object string Object string
} }
func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) { func startUpApp() (app *fiber.App, option *config.ApplicationConfig, loader *config.BackendConfigLoader) {
// Preparing the mocked objects // Preparing the mocked objects
loader = &config.ConfigLoader{} loader = &config.BackendConfigLoader{}
option = &options.Option{ option = &config.ApplicationConfig{
UploadLimitMB: 10, UploadLimitMB: 10,
UploadDir: "test_dir", UploadDir: "test_dir",
} }
@ -52,9 +52,9 @@ func startUpApp() (app *fiber.App, option *options.Option, loader *config.Config
func TestUploadFileExceedSizeLimit(t *testing.T) { func TestUploadFileExceedSizeLimit(t *testing.T) {
// Preparing the mocked objects // Preparing the mocked objects
loader := &config.ConfigLoader{} loader := &config.BackendConfigLoader{}
option := &options.Option{ option := &config.ApplicationConfig{
UploadLimitMB: 10, UploadLimitMB: 10,
UploadDir: "test_dir", UploadDir: "test_dir",
} }
@ -174,9 +174,9 @@ func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*htt
return app.Test(request) return app.Test(request)
} }
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) { func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
// Create a file that exceeds the limit // Create a file that exceeds the limit
file := createTestFile(t, fileName, fileSize, o) file := createTestFile(t, fileName, fileSize, appConfig)
// Creating a new HTTP Request // Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose) body, writer := newMultipartFile(file.Name(), tag, purpose)
@ -186,9 +186,9 @@ func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpos
return app.Test(req) return app.Test(req)
} }
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File { func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
// Create a file that exceeds the limit // Create a file that exceeds the limit
file := createTestFile(t, fileName, fileSize, o) file := createTestFile(t, fileName, fileSize, appConfig)
// Creating a new HTTP Request // Creating a new HTTP Request
body, writer := newMultipartFile(file.Name(), tag, purpose) body, writer := newMultipartFile(file.Name(), tag, purpose)
@ -233,7 +233,7 @@ func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipar
} }
// Helper to create test files // Helper to create test files
func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File { func createTestFile(t *testing.T, name string, sizeMB int, option *config.ApplicationConfig) *os.File {
err := os.MkdirAll(option.UploadDir, 0755) err := os.MkdirAll(option.UploadDir, 0755)
if err != nil { if err != nil {

View File

@ -13,12 +13,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -59,9 +59,9 @@ func downloadFile(url string) (string, error) {
* *
*/ */
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, o, false) m, input, err := readRequest(c, ml, appConfig, false)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -71,7 +71,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
} }
log.Debug().Msgf("Loading model: %+v", m) log.Debug().Msgf("Loading model: %+v", m)
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -104,7 +104,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
} }
// Create a temporary file // Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64") outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
if err != nil { if err != nil {
return err return err
} }
@ -133,15 +133,15 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
sizeParts := strings.Split(input.Size, "x") sizeParts := strings.Split(input.Size, "x")
if len(sizeParts) != 2 { if len(sizeParts) != 2 {
return fmt.Errorf("Invalid value for 'size'") return fmt.Errorf("invalid value for 'size'")
} }
width, err := strconv.Atoi(sizeParts[0]) width, err := strconv.Atoi(sizeParts[0])
if err != nil { if err != nil {
return fmt.Errorf("Invalid value for 'size'") return fmt.Errorf("invalid value for 'size'")
} }
height, err := strconv.Atoi(sizeParts[1]) height, err := strconv.Atoi(sizeParts[1])
if err != nil { if err != nil {
return fmt.Errorf("Invalid value for 'size'") return fmt.Errorf("invalid value for 'size'")
} }
b64JSON := false b64JSON := false
@ -179,7 +179,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
tempDir := "" tempDir := ""
if !b64JSON { if !b64JSON {
tempDir = o.ImageDir tempDir = appConfig.ImageDir
} }
// Create a temporary file // Create a temporary file
outputFile, err := os.CreateTemp(tempDir, "b64") outputFile, err := os.CreateTemp(tempDir, "b64")
@ -196,7 +196,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
baseURL := c.BaseURL() baseURL := c.BaseURL()
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o) fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
if err != nil { if err != nil {
return err return err
} }

View File

@ -2,8 +2,8 @@ package openai
import ( import (
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
) )
@ -11,8 +11,8 @@ import (
func ComputeChoices( func ComputeChoices(
req *schema.OpenAIRequest, req *schema.OpenAIRequest,
predInput string, predInput string,
config *config.Config, config *config.BackendConfig,
o *options.Option, o *config.ApplicationConfig,
loader *model.ModelLoader, loader *model.ModelLoader,
cb func(string, *[]schema.Choice), cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {

View File

@ -3,15 +3,15 @@ package openai
import ( import (
"regexp" "regexp"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
models, err := loader.ListModels() models, err := ml.ListModels()
if err != nil { if err != nil {
return err return err
} }
@ -40,7 +40,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
excludeConfigured := c.QueryBool("excludeConfigured", true) excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations // Start with the known configurations
for _, c := range cm.GetAllConfigs() { for _, c := range cl.GetAllBackendConfigs() {
if excludeConfigured { if excludeConfigured {
mm[c.Model] = nil mm[c.Model] = nil
} }

View File

@ -5,13 +5,12 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"strings" "strings"
fiberContext "github.com/go-skynet/LocalAI/api/ctx" "github.com/go-skynet/LocalAI/core/config"
config "github.com/go-skynet/LocalAI/core/config" fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
options "github.com/go-skynet/LocalAI/core/options"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grammar"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
@ -19,11 +18,9 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *schema.OpenAIRequest, error) { func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest) input := new(schema.OpenAIRequest)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// Get input data from the request body // Get input data from the request body
if err := c.BodyParser(input); err != nil { if err := c.BodyParser(input); err != nil {
return "", nil, fmt.Errorf("failed parsing request body: %w", err) return "", nil, fmt.Errorf("failed parsing request body: %w", err)
@ -31,9 +28,13 @@ func readRequest(c *fiber.Ctx, o *options.Option, firstModel bool) (string, *sch
received, _ := json.Marshal(input) received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received)) log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, o.Loader, input.Model, firstModel) modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
return modelFile, input, err return modelFile, input, err
} }
@ -50,7 +51,7 @@ func getBase64Image(s string) (string, error) {
defer resp.Body.Close() defer resp.Body.Close()
// read the image data into memory // read the image data into memory
data, err := ioutil.ReadAll(resp.Body) data, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -69,7 +70,7 @@ func getBase64Image(s string) (string, error) {
return "", fmt.Errorf("not valid string") return "", fmt.Errorf("not valid string")
} }
func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) { func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo { if input.Echo {
config.Echo = input.Echo config.Echo = input.Echo
} }
@ -270,8 +271,8 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
} }
} }
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) { func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
cfg, err := config.Load(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16) cfg, err := config.LoadBackendConfigFileByName(modelFile, loader.ModelPath, cm, debug, threads, ctx, f16)
// Set the parameters for the language model prediction // Set the parameters for the language model prediction
updateRequestConfig(cfg, input) updateRequestConfig(cfg, input)

View File

@ -9,22 +9,22 @@ import (
"path/filepath" "path/filepath"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/options" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// https://platform.openai.com/docs/api-reference/audio/create // https://platform.openai.com/docs/api-reference/audio/create
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
m, input, err := readRequest(c, o, false) m, input, err := readRequest(c, ml, appConfig, false)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := mergeRequestWithConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -59,7 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Audio file copied to: %+v", dst) log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
if err != nil { if err != nil {
return err return err
} }

21
core/schema/localai.go Normal file
View File

@ -0,0 +1,21 @@
package schema
import (
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}
type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Backend string `json:"backend" yaml:"backend"`
}

View File

@ -3,8 +3,6 @@ package schema
import ( import (
"context" "context"
config "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grammar"
) )
@ -108,10 +106,10 @@ type ChatCompletionResponseFormat struct {
} }
type OpenAIRequest struct { type OpenAIRequest struct {
config.PredictionOptions PredictionOptions
Context context.Context Context context.Context `json:"-"`
Cancel context.CancelFunc Cancel context.CancelFunc `json:"-"`
// whisper // whisper
File string `json:"file" validate:"required"` File string `json:"file" validate:"required"`

View File

@ -1,4 +1,4 @@
package config package schema
type PredictionOptions struct { type PredictionOptions struct {

View File

@ -0,0 +1,140 @@
package services
import (
"context"
"fmt"
"strings"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
type BackendMonitor struct {
configLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
modelLoader: modelLoader,
options: appConfig,
}
}
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetBackendConfig(modelName)
var backendId string
if exists {
backendId = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backendId = modelName
}
if !strings.HasSuffix(backendId, ".bin") {
backendId = fmt.Sprintf("%s.bin", backendId)
}
return backendId, nil
}
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetBackendConfig(model)
var backend string
if exists {
backend = config.Model
} else {
// Last ditch effort: use it raw, see if a backend happens to match.
backend = model
}
if !strings.HasSuffix(backend, ".bin") {
backend = fmt.Sprintf("%s.bin", backend)
}
pid, err := bm.modelLoader.GetGRPCPID(backend)
if err != nil {
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
return nil, err
}
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
backendProcess, err := gopsutil.NewProcess(int32(pid))
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
return nil, err
}
memInfo, err := backendProcess.MemoryInfo()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
return nil, err
}
memPercent, err := backendProcess.MemoryPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
return nil, err
}
cpuPercent, err := backendProcess.CPUPercent()
if err != nil {
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
return nil, err
}
return &schema.BackendMonitorResponse{
MemoryInfo: memInfo,
MemoryPercent: memPercent,
CPUPercent: cpuPercent,
}, nil
}
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return nil, err
}
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
}
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId)
if slbErr != nil {
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
}
return &proto.StatusResponse{
State: proto.StatusResponse_ERROR,
Memory: &proto.MemoryUsageData{
Total: val.MemoryInfo.VMS,
Breakdown: map[string]uint64{
"gopsutil-RSS": val.MemoryInfo.RSS,
},
},
}, nil
}
return status, nil
}
func (bm BackendMonitor) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
if err != nil {
return err
}
return bm.modelLoader.ShutdownModel(backendId)
}

167
core/services/gallery.go Normal file
View File

@ -0,0 +1,167 @@
package services
import (
"context"
"encoding/json"
"os"
"strings"
"sync"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v2"
)
type GalleryService struct {
modelPath string
sync.Mutex
C chan gallery.GalleryOp
statuses map[string]*gallery.GalleryOpStatus
}
func NewGalleryService(modelPath string) *GalleryService {
return &GalleryService{
modelPath: modelPath,
C: make(chan gallery.GalleryOp),
statuses: make(map[string]*gallery.GalleryOpStatus),
}
}
func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
}
func (g *GalleryService) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *GalleryService) GetStatus(s string) *gallery.GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *GalleryService) GetAllStatus() map[string]*gallery.GalleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses
}
func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
utils.ResetDownloadTimers()
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error
updateError := func(e error) {
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
}
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" {
if strings.Contains(op.GalleryName, "@") {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
}
} else {
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
}
if err != nil {
updateError(err)
continue
}
// Reload models
err = cl.LoadBackendConfigsFromPath(g.modelPath)
if err != nil {
updateError(err)
continue
}
err = cl.Preload(g.modelPath)
if err != nil {
updateError(err)
continue
}
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
}
}
}()
}
type galleryModel struct {
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
ID string `json:"id"`
}
func processRequests(modelPath, s string, cm *config.BackendConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else {
if strings.Contains(r.ID, "@") {
err = gallery.InstallModelFromGallery(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} else {
err = gallery.InstallModelFromGalleryByName(
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
}
}
}
return err
}
func ApplyGalleryFromFile(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error {
dat, err := os.ReadFile(s)
if err != nil {
return err
}
var requests []galleryModel
if err := yaml.Unmarshal(dat, &requests); err != nil {
return err
}
return processRequests(modelPath, s, cl, galleries, requests)
}
func ApplyGalleryFromString(modelPath, s string, cl *config.BackendConfigLoader, galleries []gallery.Gallery) error {
var requests []galleryModel
err := json.Unmarshal([]byte(s), &requests)
if err != nil {
return err
}
return processRequests(modelPath, s, cl, galleries, requests)
}

54
core/services/metrics.go Normal file
View File

@ -0,0 +1,54 @@
package services
import (
"context"
"github.com/rs/zerolog/log"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/metric"
metricApi "go.opentelemetry.io/otel/sdk/metric"
)
type LocalAIMetricsService struct {
Meter metric.Meter
ApiTimeMetric metric.Float64Histogram
}
func (m *LocalAIMetricsService) ObserveAPICall(method string, path string, duration float64) {
opts := metric.WithAttributes(
attribute.String("method", method),
attribute.String("path", path),
)
m.ApiTimeMetric.Record(context.Background(), duration, opts)
}
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func NewLocalAIMetricsService() (*LocalAIMetricsService, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metricApi.NewMeterProvider(metricApi.WithReader(exporter))
meter := provider.Meter("github.com/go-skynet/LocalAI")
apiTimeMetric, err := meter.Float64Histogram("api_call", metric.WithDescription("api calls"))
if err != nil {
return nil, err
}
return &LocalAIMetricsService{
Meter: meter,
ApiTimeMetric: apiTimeMetric,
}, nil
}
func (lams LocalAIMetricsService) Shutdown() error {
// TODO: Not sure how to actually do this:
//// setupOTelSDK bootstraps the OpenTelemetry pipeline.
//// If it does not return an error, make sure to call shutdown for proper cleanup.
log.Warn().Msgf("LocalAIMetricsService Shutdown called, but OTelSDK proper shutdown not yet implemented?")
return nil
}

View File

@ -0,0 +1,100 @@
package startup
import (
"encoding/json"
"fmt"
"os"
"path"
"github.com/fsnotify/fsnotify"
"github.com/go-skynet/LocalAI/core/config"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
)
type WatchConfigDirectoryCloser func() error
func ReadApiKeysJson(configDir string, appConfig *config.ApplicationConfig) error {
fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err == nil {
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
return nil
}
return err
}
return err
}
func ReadExternalBackendsJson(configDir string, appConfig *config.ApplicationConfig) error {
fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
if err != nil {
return err
}
// Parse JSON content from the file
var fileBackends map[string]string
err = json.Unmarshal(fileContent, &fileBackends)
if err != nil {
return err
}
err = mergo.Merge(&appConfig.ExternalGRPCBackends, fileBackends)
if err != nil {
return err
}
return nil
}
var CONFIG_FILE_UPDATES = map[string]func(configDir string, appConfig *config.ApplicationConfig) error{
"api_keys.json": ReadApiKeysJson,
"external_backends.json": ReadExternalBackendsJson,
}
func WatchConfigDirectory(configDir string, appConfig *config.ApplicationConfig) (WatchConfigDirectoryCloser, error) {
if len(configDir) == 0 {
return nil, fmt.Errorf("configDir blank")
}
configWatcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err)
}
ret := func() error {
configWatcher.Close()
return nil
}
// Start listening for events.
go func() {
for {
select {
case event, ok := <-configWatcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
for targetName, watchFn := range CONFIG_FILE_UPDATES {
if event.Name == targetName {
err := watchFn(configDir, appConfig)
log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
}
}
}
case _, ok := <-configWatcher.Errors:
if !ok {
return
}
log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err)
}
}
}()
// Add a path.
err = configWatcher.Add(configDir)
if err != nil {
return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err)
}
return ret, nil
}

128
core/startup/startup.go Normal file
View File

@ -0,0 +1,128 @@
package startup
import (
"fmt"
"os"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
pkgStartup "github.com/go-skynet/LocalAI/pkg/startup"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.ModelLoader, *config.ApplicationConfig, error) {
options := config.NewApplicationConfig(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
// Make sure directories exists
if options.ModelPath == "" {
return nil, nil, nil, fmt.Errorf("options.ModelPath cannot be empty")
}
err := os.MkdirAll(options.ModelPath, 0755)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create ModelPath: %q", err)
}
if options.ImageDir != "" {
err := os.MkdirAll(options.ImageDir, 0755)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create ImageDir: %q", err)
}
}
if options.AudioDir != "" {
err := os.MkdirAll(options.AudioDir, 0755)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create AudioDir: %q", err)
}
}
if options.UploadDir != "" {
err := os.MkdirAll(options.UploadDir, 0755)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to create UploadDir: %q", err)
}
}
//
pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...)
cl := config.NewBackendConfigLoader()
ml := model.NewModelLoader(options.ModelPath)
if err := cl.LoadBackendConfigsFromPath(options.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadBackendConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if err := cl.Preload(options.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}
if options.PreloadJSONModels != "" {
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, nil, err
}
}
if options.Debug {
for _, v := range cl.ListBackendConfigs() {
cfg, _ := cl.GetBackendConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
ml.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
ml,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
ml.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
log.Info().Msg("core/startup process completed!")
return cl, ml, options, nil
}

View File

@ -6,6 +6,12 @@ meta {
get { get {
url: {{PROTOCOL}}{{HOST}}:{{PORT}}/backend/monitor url: {{PROTOCOL}}{{HOST}}:{{PORT}}/backend/monitor
body: none body: json
auth: none auth: none
} }
body:json {
{
"model": "{{DEFAULT_MODEL}}"
}
}

View File

@ -4,7 +4,7 @@ import { Document } from "langchain/document";
import { initializeAgentExecutorWithOptions } from "langchain/agents"; import { initializeAgentExecutorWithOptions } from "langchain/agents";
import {Calculator} from "langchain/tools/calculator"; import {Calculator} from "langchain/tools/calculator";
const pathToLocalAi = process.env['OPENAI_API_BASE'] || 'http://api:8080/v1'; const pathToLocalAI = process.env['OPENAI_API_BASE'] || 'http://api:8080/v1';
const fakeApiKey = process.env['OPENAI_API_KEY'] || '-'; const fakeApiKey = process.env['OPENAI_API_KEY'] || '-';
const modelName = process.env['MODEL_NAME'] || 'gpt-3.5-turbo'; const modelName = process.env['MODEL_NAME'] || 'gpt-3.5-turbo';
@ -21,7 +21,7 @@ function getModel(): OpenAIChat {
openAIApiKey: fakeApiKey, openAIApiKey: fakeApiKey,
maxRetries: 2 maxRetries: 2
}, { }, {
basePath: pathToLocalAi, basePath: pathToLocalAI,
apiKey: fakeApiKey, apiKey: fakeApiKey,
}); });
} }

6
go.mod
View File

@ -5,6 +5,7 @@ go 1.21
require ( require (
github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
github.com/fsnotify/fsnotify v1.7.0
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
github.com/go-audio/wav v1.1.0 github.com/go-audio/wav v1.1.0
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
@ -14,7 +15,6 @@ require (
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/hpcloud/tail v1.0.0 github.com/hpcloud/tail v1.0.0
github.com/imdario/mergo v0.3.16 github.com/imdario/mergo v0.3.16
github.com/json-iterator/go v1.1.12
github.com/mholt/archiver/v3 v3.5.1 github.com/mholt/archiver/v3 v3.5.1
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
@ -64,8 +64,6 @@ require (
github.com/klauspost/pgzip v1.2.5 // indirect github.com/klauspost/pgzip v1.2.5 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/nwaples/rardecode v1.1.0 // indirect github.com/nwaples/rardecode v1.1.0 // indirect
github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect
github.com/pkoukk/tiktoken-go v0.1.2 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect
@ -104,7 +102,7 @@ require (
github.com/valyala/tcplisten v1.0.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/net v0.17.0 // indirect golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.12.0 // indirect golang.org/x/tools v0.12.0 // indirect
) )

12
go.sum
View File

@ -26,6 +26,8 @@ github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdf
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ=
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
@ -72,7 +74,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
@ -86,8 +87,6 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw=
github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
@ -117,11 +116,6 @@ github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Cl
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4= github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI= github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI=
@ -278,6 +272,8 @@ golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=

121
main.go
View File

@ -13,11 +13,12 @@ import (
"time" "time"
"github.com/go-skynet/LocalAI/core/backend" "github.com/go-skynet/LocalAI/core/backend"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
api "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/core/http"
"github.com/go-skynet/LocalAI/core/startup"
"github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -206,6 +207,12 @@ func main() {
EnvVars: []string{"PRELOAD_BACKEND_ONLY"}, EnvVars: []string{"PRELOAD_BACKEND_ONLY"},
Value: false, Value: false,
}, },
&cli.StringFlag{
Name: "localai-config-dir",
Usage: "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.",
EnvVars: []string{"LOCALAI_CONFIG_DIR"},
Value: "./configuration",
},
}, },
Description: ` Description: `
LocalAI is a drop-in replacement OpenAI API which runs inference locally. LocalAI is a drop-in replacement OpenAI API which runs inference locally.
@ -224,56 +231,56 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
UsageText: `local-ai [options]`, UsageText: `local-ai [options]`,
Copyright: "Ettore Di Giacinto", Copyright: "Ettore Di Giacinto",
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
opts := []options.AppOption{ opts := []config.AppOption{
options.WithConfigFile(ctx.String("config-file")), config.WithConfigFile(ctx.String("config-file")),
options.WithJSONStringPreload(ctx.String("preload-models")), config.WithJSONStringPreload(ctx.String("preload-models")),
options.WithYAMLConfigPreload(ctx.String("preload-models-config")), config.WithYAMLConfigPreload(ctx.String("preload-models-config")),
options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), config.WithModelPath(ctx.String("models-path")),
options.WithContextSize(ctx.Int("context-size")), config.WithContextSize(ctx.Int("context-size")),
options.WithDebug(ctx.Bool("debug")), config.WithDebug(ctx.Bool("debug")),
options.WithImageDir(ctx.String("image-path")), config.WithImageDir(ctx.String("image-path")),
options.WithAudioDir(ctx.String("audio-path")), config.WithAudioDir(ctx.String("audio-path")),
options.WithUploadDir(ctx.String("upload-path")), config.WithUploadDir(ctx.String("upload-path")),
options.WithF16(ctx.Bool("f16")), config.WithF16(ctx.Bool("f16")),
options.WithStringGalleries(ctx.String("galleries")), config.WithStringGalleries(ctx.String("galleries")),
options.WithModelLibraryURL(ctx.String("remote-library")), config.WithModelLibraryURL(ctx.String("remote-library")),
options.WithDisableMessage(false), config.WithDisableMessage(false),
options.WithCors(ctx.Bool("cors")), config.WithCors(ctx.Bool("cors")),
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), config.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
options.WithThreads(ctx.Int("threads")), config.WithThreads(ctx.Int("threads")),
options.WithBackendAssets(backendAssets), config.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), config.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
options.WithUploadLimitMB(ctx.Int("upload-limit")), config.WithUploadLimitMB(ctx.Int("upload-limit")),
options.WithApiKeys(ctx.StringSlice("api-keys")), config.WithApiKeys(ctx.StringSlice("api-keys")),
options.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...), config.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...),
} }
idleWatchDog := ctx.Bool("enable-watchdog-idle") idleWatchDog := ctx.Bool("enable-watchdog-idle")
busyWatchDog := ctx.Bool("enable-watchdog-busy") busyWatchDog := ctx.Bool("enable-watchdog-busy")
if idleWatchDog || busyWatchDog { if idleWatchDog || busyWatchDog {
opts = append(opts, options.EnableWatchDog) opts = append(opts, config.EnableWatchDog)
if idleWatchDog { if idleWatchDog {
opts = append(opts, options.EnableWatchDogIdleCheck) opts = append(opts, config.EnableWatchDogIdleCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout")) dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
if err != nil { if err != nil {
return err return err
} }
opts = append(opts, options.SetWatchDogIdleTimeout(dur)) opts = append(opts, config.SetWatchDogIdleTimeout(dur))
} }
if busyWatchDog { if busyWatchDog {
opts = append(opts, options.EnableWatchDogBusyCheck) opts = append(opts, config.EnableWatchDogBusyCheck)
dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout")) dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
if err != nil { if err != nil {
return err return err
} }
opts = append(opts, options.SetWatchDogBusyTimeout(dur)) opts = append(opts, config.SetWatchDogBusyTimeout(dur))
} }
} }
if ctx.Bool("parallel-requests") { if ctx.Bool("parallel-requests") {
opts = append(opts, options.EnableParallelBackendRequests) opts = append(opts, config.EnableParallelBackendRequests)
} }
if ctx.Bool("single-active-backend") { if ctx.Bool("single-active-backend") {
opts = append(opts, options.EnableSingleBackend) opts = append(opts, config.EnableSingleBackend)
} }
externalgRPC := ctx.StringSlice("external-grpc-backends") externalgRPC := ctx.StringSlice("external-grpc-backends")
@ -281,30 +288,38 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
for _, v := range externalgRPC { for _, v := range externalgRPC {
backend := v[:strings.IndexByte(v, ':')] backend := v[:strings.IndexByte(v, ':')]
uri := v[strings.IndexByte(v, ':')+1:] uri := v[strings.IndexByte(v, ':')+1:]
opts = append(opts, options.WithExternalBackend(backend, uri)) opts = append(opts, config.WithExternalBackend(backend, uri))
} }
if ctx.Bool("autoload-galleries") { if ctx.Bool("autoload-galleries") {
opts = append(opts, options.EnableGalleriesAutoload) opts = append(opts, config.EnableGalleriesAutoload)
} }
if ctx.Bool("preload-backend-only") { if ctx.Bool("preload-backend-only") {
_, _, err := api.Startup(opts...) _, _, _, err := startup.Startup(opts...)
return err return err
} }
metrics, err := metrics.SetupMetrics() cl, ml, options, err := startup.Startup(opts...)
if err != nil { if err != nil {
return err return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
} }
opts = append(opts, options.WithMetrics(metrics))
app, err := api.App(opts...) closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options)
defer closeConfigWatcherFn()
if err != nil { if err != nil {
return fmt.Errorf("failed while watching configuration directory %s", ctx.String("localai-config-dir"))
}
appHTTP, err := http.App(cl, ml, options)
if err != nil {
log.Error().Msg("Error during HTTP App constructor")
return err return err
} }
return app.Listen(ctx.String("address")) return appHTTP.Listen(ctx.String("address"))
}, },
Commands: []*cli.Command{ Commands: []*cli.Command{
{ {
@ -402,16 +417,17 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
text := strings.Join(ctx.Args().Slice(), " ") text := strings.Join(ctx.Args().Slice(), " ")
opts := &options.Option{ opts := &config.ApplicationConfig{
Loader: model.NewModelLoader(ctx.String("models-path")), ModelPath: ctx.String("models-path"),
Context: context.Background(), Context: context.Background(),
AudioDir: outputDir, AudioDir: outputDir,
AssetsDestination: ctx.String("backend-assets-path"), AssetsDestination: ctx.String("backend-assets-path"),
} }
ml := model.NewModelLoader(opts.ModelPath)
defer opts.Loader.StopAllGRPC() defer ml.StopAllGRPC()
filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts, config.Config{}) filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, ml, opts, config.BackendConfig{})
if err != nil { if err != nil {
return err return err
} }
@ -464,27 +480,28 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
language := ctx.String("language") language := ctx.String("language")
threads := ctx.Int("threads") threads := ctx.Int("threads")
opts := &options.Option{ opts := &config.ApplicationConfig{
Loader: model.NewModelLoader(ctx.String("models-path")), ModelPath: ctx.String("models-path"),
Context: context.Background(), Context: context.Background(),
AssetsDestination: ctx.String("backend-assets-path"), AssetsDestination: ctx.String("backend-assets-path"),
} }
cl := config.NewConfigLoader() cl := config.NewBackendConfigLoader()
if err := cl.LoadConfigs(ctx.String("models-path")); err != nil { ml := model.NewModelLoader(opts.ModelPath)
if err := cl.LoadBackendConfigsFromPath(ctx.String("models-path")); err != nil {
return err return err
} }
c, exists := cl.GetConfig(modelOption) c, exists := cl.GetBackendConfig(modelOption)
if !exists { if !exists {
return errors.New("model not found") return errors.New("model not found")
} }
c.Threads = threads c.Threads = threads
defer opts.Loader.StopAllGRPC() defer ml.StopAllGRPC()
tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts) tr, err := backend.ModelTranscription(filename, language, ml, c, opts)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,83 +0,0 @@
package metrics
import (
"context"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/prometheus"
api "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/metric"
)
type Metrics struct {
meter api.Meter
apiTimeMetric api.Float64Histogram
}
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
// If it does not return an error, make sure to call shutdown for proper cleanup.
func SetupMetrics() (*Metrics, error) {
exporter, err := prometheus.New()
if err != nil {
return nil, err
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
meter := provider.Meter("github.com/go-skynet/LocalAI")
apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
if err != nil {
return nil, err
}
return &Metrics{
meter: meter,
apiTimeMetric: apiTimeMetric,
}, nil
}
func MetricsHandler() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metrics *Metrics
}
func APIMiddleware(metrics *Metrics) fiber.Handler {
cfg := apiMiddlewareConfig{
metrics: metrics,
Filter: func(c *fiber.Ctx) bool {
if c.Path() == "/metrics" {
return true
}
return false
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metrics.ObserveAPICall(method, path, elapsed)
return err
}
}
func (m *Metrics) ObserveAPICall(method string, path string, duration float64) {
opts := api.WithAttributes(
attribute.String("method", method),
attribute.String("path", path),
)
m.apiTimeMetric.Record(context.Background(), duration, opts)
}

View File

@ -179,6 +179,10 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 400 {
return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode)
}
// Create parent directory // Create parent directory
err = os.MkdirAll(filepath.Dir(filePath), 0755) err = os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil { if err != nil {

View File

@ -18,7 +18,6 @@ var _ = Describe("Model test", func() {
defer os.RemoveAll(tempdir) defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

18
pkg/gallery/op.go Normal file
View File

@ -0,0 +1,18 @@
package gallery
type GalleryOp struct {
Req GalleryModel
Id string
Galleries []Gallery
GalleryName string
}
type GalleryOpStatus struct {
FileName string `json:"file_name"`
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
Progress float64 `json:"progress"`
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
}

View File

@ -3,7 +3,7 @@ package integration_test
import ( import (
"reflect" "reflect"
config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"