refactor(application): introduce application global state (#2072)

* start breaking up the giant channel refactor now that it's better understood - easier to merge bites

Signed-off-by: Dave Lee <dave@gray101.com>

* add concurrency and base64 back in, along with new base64 tests.

Signed-off-by: Dave Lee <dave@gray101.com>

* Automatic rename of whisper.go's Result to TranscriptResult

Signed-off-by: Dave Lee <dave@gray101.com>

* remove pkg/concurrency - significant changes coming in split 2

Signed-off-by: Dave Lee <dave@gray101.com>

* fix comments

Signed-off-by: Dave Lee <dave@gray101.com>

* add list_model service as another low-risk service to get it out of the way

Signed-off-by: Dave Lee <dave@gray101.com>

* split backend config loader into seperate file from the actual config struct. No changes yet, just reduce cognative load with smaller files of logical blocks

Signed-off-by: Dave Lee <dave@gray101.com>

* rename state.go ==> application.go

Signed-off-by: Dave Lee <dave@gray101.com>

* fix lost import?

Signed-off-by: Dave Lee <dave@gray101.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
This commit is contained in:
Dave 2024-04-29 13:42:37 -04:00 committed by GitHub
parent 147440b39b
commit c4f958e11b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 590 additions and 422 deletions

View File

@ -29,8 +29,8 @@ func audioToWav(src, dst string) error {
return nil return nil
} }
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) { func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) {
res := schema.Result{} res := schema.TranscriptionResult{}
dir, err := os.MkdirTemp("", "whisper") dir, err := os.MkdirTemp("", "whisper")
if err != nil { if err != nil {

View File

@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) { func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
} }

39
core/application.go Normal file
View File

@ -0,0 +1,39 @@
package core
import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model"
)
// The purpose of this structure is to hold pointers to all initialized services, to make plumbing easy
// Perhaps a proper DI system is worth it in the future, but for now keep things simple.
type Application struct {
// Application-Level Config
ApplicationConfig *config.ApplicationConfig
// ApplicationState *ApplicationState
// Core Low-Level Services
BackendConfigLoader *config.BackendConfigLoader
ModelLoader *model.ModelLoader
// Backend Services
// EmbeddingsBackendService *backend.EmbeddingsBackendService
// ImageGenerationBackendService *backend.ImageGenerationBackendService
// LLMBackendService *backend.LLMBackendService
// TranscriptionBackendService *backend.TranscriptionBackendService
// TextToSpeechBackendService *backend.TextToSpeechBackendService
// LocalAI System Services
BackendMonitorService *services.BackendMonitorService
GalleryService *services.GalleryService
ListModelsService *services.ListModelsService
LocalAIMetricsService *services.LocalAIMetricsService
// OpenAIService *services.OpenAIService
}
// TODO [NEXT PR?]: Break up ApplicationConfig.
// Migrate over stuff that is not set via config at all - especially runtime stuff
type ApplicationState struct {
}

View File

@ -11,7 +11,7 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
) )
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
opts := modelOpts(backendConfig, appConfig, []model.Option{ opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend), model.WithBackendString(model.WhisperBackend),

View File

@ -1,23 +1,12 @@
package config package config
import ( import (
"errors"
"fmt"
"io/fs"
"os" "os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/functions" "github.com/go-skynet/LocalAI/pkg/functions"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
"github.com/charmbracelet/glamour"
) )
const ( const (
@ -343,303 +332,3 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
cfg.Debug = &trueV cfg.Debug = &trueV
} }
} }
////// Config Loader ////////
type BackendConfigLoader struct {
configs map[string]BackendConfig
sync.Mutex
}
type LoadOptions struct {
debug bool
threads, ctxSize int
f16 bool
}
func LoadOptionDebug(debug bool) ConfigLoaderOption {
return func(o *LoadOptions) {
o.debug = debug
}
}
func LoadOptionThreads(threads int) ConfigLoaderOption {
return func(o *LoadOptions) {
o.threads = threads
}
}
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
return func(o *LoadOptions) {
o.ctxSize = ctxSize
}
}
func LoadOptionF16(f16 bool) ConfigLoaderOption {
return func(o *LoadOptions) {
o.f16 = f16
}
}
type ConfigLoaderOption func(*LoadOptions)
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
for _, l := range options {
l(lo)
}
}
// Load a config file for a model
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}
cfg.SetDefaults(opts...)
return cfg, nil
}
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
for _, cc := range *c {
cc.SetDefaults(opts...)
}
return *c, nil
}
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
c.SetDefaults(opts...)
return c, nil
}
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadBackendConfigFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm.configs[cc.Name] = *cc
}
return nil
}
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cl.configs[c.Name] = *c
return nil
}
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cl.Lock()
defer cl.Unlock()
v, exists := cl.configs[m]
return v, exists
}
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cl.Lock()
defer cl.Unlock()
var res []BackendConfig
for _, v := range cl.configs {
res = append(res, v)
}
sort.SliceStable(res, func(i, j int) bool {
return res[i].Name < res[j].Name
})
return res
}
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
cl.Lock()
defer cl.Unlock()
var res []string
for k := range cl.configs {
res = append(res, k)
}
return res
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cl *BackendConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}
log.Info().Msgf("Preloading models from %s", modelPath)
renderMode := "dark"
if os.Getenv("COLOR") != "" {
renderMode = os.Getenv("COLOR")
}
glamText := func(t string) {
out, err := glamour.Render(t, renderMode)
if err == nil && os.Getenv("NO_COLOR") == "" {
fmt.Println(out)
} else {
fmt.Println(t)
}
}
for i, config := range cl.configs {
// Download files and verify their SHA
for i, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
return err
}
}
// If the model is an URL, expand it, and download the file
if config.IsModelURL() {
modelFileName := config.ModelFileName()
modelURL := downloader.ConvertURL(config.Model)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}
cc := cl.configs[i]
c := &cc
c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c
}
if config.IsMMProjURL() {
modelFileName := config.MMProjFileName()
modelURL := downloader.ConvertURL(config.MMProj)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}
cc := cl.configs[i]
c := &cc
c.MMProj = modelFileName
cl.configs[i] = *c
}
if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
}
if cl.configs[i].Description != "" {
//glamText("**Description**")
glamText(cl.configs[i].Description)
}
if cl.configs[i].Usage != "" {
//glamText("**Usage**")
glamText(cl.configs[i].Usage)
}
}
return nil
}
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
return err
}
files = append(files, info)
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") ||
strings.HasPrefix(file.Name(), ".") {
continue
}
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
if err == nil {
cm.configs[c.Name] = *c
}
}
return nil
}

View File

@ -0,0 +1,317 @@
package config
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/charmbracelet/glamour"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/downloader"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
type BackendConfigLoader struct {
configs map[string]BackendConfig
sync.Mutex
}
type LoadOptions struct {
debug bool
threads, ctxSize int
f16 bool
}
func LoadOptionDebug(debug bool) ConfigLoaderOption {
return func(o *LoadOptions) {
o.debug = debug
}
}
func LoadOptionThreads(threads int) ConfigLoaderOption {
return func(o *LoadOptions) {
o.threads = threads
}
}
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
return func(o *LoadOptions) {
o.ctxSize = ctxSize
}
}
func LoadOptionF16(f16 bool) ConfigLoaderOption {
return func(o *LoadOptions) {
o.f16 = f16
}
}
type ConfigLoaderOption func(*LoadOptions)
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
for _, l := range options {
l(lo)
}
}
// Load a config file for a model
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}
cfgExisting, exists := cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}
cfg.SetDefaults(opts...)
return cfg, nil
}
func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
for _, cc := range *c {
cc.SetDefaults(opts...)
}
return *c, nil
}
func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
c.SetDefaults(opts...)
return c, nil
}
func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadBackendConfigFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm.configs[cc.Name] = *cc
}
return nil
}
func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cl.configs[c.Name] = *c
return nil
}
func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cl.Lock()
defer cl.Unlock()
v, exists := cl.configs[m]
return v, exists
}
func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cl.Lock()
defer cl.Unlock()
var res []BackendConfig
for _, v := range cl.configs {
res = append(res, v)
}
sort.SliceStable(res, func(i, j int) bool {
return res[i].Name < res[j].Name
})
return res
}
func (cl *BackendConfigLoader) ListBackendConfigs() []string {
cl.Lock()
defer cl.Unlock()
var res []string
for k := range cl.configs {
res = append(res, k)
}
return res
}
// Preload prepare models if they are not local but url or huggingface repositories
func (cl *BackendConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}
log.Info().Msgf("Preloading models from %s", modelPath)
renderMode := "dark"
if os.Getenv("COLOR") != "" {
renderMode = os.Getenv("COLOR")
}
glamText := func(t string) {
out, err := glamour.Render(t, renderMode)
if err == nil && os.Getenv("NO_COLOR") == "" {
fmt.Println(out)
} else {
fmt.Println(t)
}
}
for i, config := range cl.configs {
// Download files and verify their SHA
for i, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)
if err := downloader.DownloadFile(file.URI, filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil {
return err
}
}
// If the model is an URL, expand it, and download the file
if config.IsModelURL() {
modelFileName := config.ModelFileName()
modelURL := downloader.ConvertURL(config.Model)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}
cc := cl.configs[i]
c := &cc
c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c
}
if config.IsMMProjURL() {
modelFileName := config.MMProjFileName()
modelURL := downloader.ConvertURL(config.MMProj)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}
cc := cl.configs[i]
c := &cc
c.MMProj = modelFileName
cl.configs[i] = *c
}
if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
}
if cl.configs[i].Description != "" {
//glamText("**Description**")
glamText(cl.configs[i].Description)
}
if cl.configs[i].Usage != "" {
//glamText("**Usage**")
glamText(cl.configs[i].Usage)
}
}
return nil
}
// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
return err
}
files = append(files, info)
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") ||
strings.HasPrefix(file.Name(), ".") {
continue
}
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
if err == nil {
cm.configs[c.Name] = *c
}
}
return nil
}

View File

@ -1,9 +1,7 @@
package http package http
import ( import (
"encoding/json"
"errors" "errors"
"os"
"strings" "strings"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
@ -124,20 +122,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
return c.Next() return c.Next()
} }
// Check for api_keys.json file
fileContent, err := os.ReadFile("api_keys.json")
if err == nil {
// Parse JSON content from the file
var fileKeys []string
err := json.Unmarshal(fileContent, &fileKeys)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
}
// Add file keys to options.ApiKeys
appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...)
}
if len(appConfig.ApiKeys) == 0 { if len(appConfig.ApiKeys) == 0 {
return c.Next() return c.Next()
} }
@ -174,13 +158,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
app.Use(c) app.Use(c)
} }
// Make sure directories exists
os.MkdirAll(appConfig.ImageDir, 0750)
os.MkdirAll(appConfig.AudioDir, 0750)
os.MkdirAll(appConfig.UploadDir, 0750)
os.MkdirAll(appConfig.ConfigsDir, 0750)
os.MkdirAll(appConfig.ModelPath, 0750)
// Load config jsons // Load config jsons
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)

View File

@ -6,7 +6,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest) input := new(schema.BackendMonitorRequest)
@ -23,7 +23,7 @@ func BackendMonitorEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error
} }
} }
func BackendShutdownEndpoint(bm services.BackendMonitor) func(c *fiber.Ctx) error { func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest) input := new(schema.BackendMonitorRequest)
// Get input data from the request body // Get input data from the request body

View File

@ -1,63 +1,23 @@
package openai package openai
import ( import (
"regexp"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/core/schema"
model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/core/services"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
func ListModelsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { func ListModelsEndpoint(lms *services.ListModelsService) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
models, err := ml.ListModels() // If blank, no filter is applied.
if err != nil {
return err
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
filter := c.Query("filter") filter := c.Query("filter")
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
}
// By default, exclude any loose files that are already referenced by a configuration file. // By default, exclude any loose files that are already referenced by a configuration file.
excludeConfigured := c.QueryBool("excludeConfigured", true) excludeConfigured := c.QueryBool("excludeConfigured", true)
// Start with the known configurations dataModels, err := lms.ListModels(filter, excludeConfigured)
for _, c := range cl.GetAllBackendConfigs() { if err != nil {
if excludeConfigured { return err
mm[c.Model] = nil
} }
if filterFn(c.Name) {
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return c.JSON(struct { return c.JSON(struct {
Object string `json:"object"` Object string `json:"object"`
Data []schema.OpenAIModel `json:"data"` Data []schema.OpenAIModel `json:"data"`

View File

@ -52,9 +52,9 @@ func RegisterLocalAIRoutes(app *fiber.App,
app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint()) app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
// Experimental Backend Statistics Module // Experimental Backend Statistics Module
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor)) app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitorService))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor)) app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitorService))
app.Get("/version", auth, func(c *fiber.Ctx) error { app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct { return c.JSON(struct {

View File

@ -4,6 +4,7 @@ import (
"github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai" "github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai" "github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@ -81,6 +82,7 @@ func RegisterOpenAIRoutes(app *fiber.App,
} }
// models // models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) tmpLMS := services.NewListModelsService(ml, cl, appConfig) // TODO: once createApplication() is fully in use, reference the central instance.
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) app.Get("/v1/models", auth, openai.ListModelsEndpoint(tmpLMS))
app.Get("/models", auth, openai.ListModelsEndpoint(tmpLMS))
} }

View File

@ -10,7 +10,7 @@ type Segment struct {
Tokens []int `json:"tokens"` Tokens []int `json:"tokens"`
} }
type Result struct { type TranscriptionResult struct {
Segments []Segment `json:"segments"` Segments []Segment `json:"segments"`
Text string `json:"text"` Text string `json:"text"`
} }

View File

@ -15,22 +15,22 @@ import (
gopsutil "github.com/shirou/gopsutil/v3/process" gopsutil "github.com/shirou/gopsutil/v3/process"
) )
type BackendMonitor struct { type BackendMonitorService struct {
configLoader *config.BackendConfigLoader backendConfigLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
} }
func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor { func NewBackendMonitorService(modelLoader *model.ModelLoader, configLoader *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *BackendMonitorService {
return BackendMonitor{ return &BackendMonitorService{
configLoader: configLoader,
modelLoader: modelLoader, modelLoader: modelLoader,
backendConfigLoader: configLoader,
options: appConfig, options: appConfig,
} }
} }
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { func (bms BackendMonitorService) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetBackendConfig(modelName) config, exists := bms.backendConfigLoader.GetBackendConfig(modelName)
var backendId string var backendId string
if exists { if exists {
backendId = config.Model backendId = config.Model
@ -46,8 +46,8 @@ func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string
return backendId, nil return backendId, nil
} }
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) { func (bms *BackendMonitorService) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
config, exists := bm.configLoader.GetBackendConfig(model) config, exists := bms.backendConfigLoader.GetBackendConfig(model)
var backend string var backend string
if exists { if exists {
backend = config.Model backend = config.Model
@ -60,7 +60,7 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
backend = fmt.Sprintf("%s.bin", backend) backend = fmt.Sprintf("%s.bin", backend)
} }
pid, err := bm.modelLoader.GetGRPCPID(backend) pid, err := bms.modelLoader.GetGRPCPID(backend)
if err != nil { if err != nil {
log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid") log.Error().Err(err).Str("model", model).Msg("failed to find GRPC pid")
@ -101,12 +101,12 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
}, nil }, nil
} }
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { func (bms BackendMonitorService) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName) backendId, err := bms.getModelLoaderIDFromModelName(modelName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modelAddr := bm.modelLoader.CheckIsLoaded(backendId) modelAddr := bms.modelLoader.CheckIsLoaded(backendId)
if modelAddr == "" { if modelAddr == "" {
return nil, fmt.Errorf("backend %s is not currently loaded", backendId) return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
} }
@ -114,7 +114,7 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
if rpcErr != nil { if rpcErr != nil {
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
val, slbErr := bm.SampleLocalBackendProcess(backendId) val, slbErr := bms.SampleLocalBackendProcess(backendId)
if slbErr != nil { if slbErr != nil {
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
} }
@ -131,10 +131,10 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
return status, nil return status, nil
} }
func (bm BackendMonitor) ShutdownModel(modelName string) error { func (bms BackendMonitorService) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName) backendId, err := bms.getModelLoaderIDFromModelName(modelName)
if err != nil { if err != nil {
return err return err
} }
return bm.modelLoader.ShutdownModel(backendId) return bms.modelLoader.ShutdownModel(backendId)
} }

View File

@ -0,0 +1,72 @@
package services
import (
"regexp"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/pkg/model"
)
type ListModelsService struct {
bcl *config.BackendConfigLoader
ml *model.ModelLoader
appConfig *config.ApplicationConfig
}
func NewListModelsService(ml *model.ModelLoader, bcl *config.BackendConfigLoader, appConfig *config.ApplicationConfig) *ListModelsService {
return &ListModelsService{
bcl: bcl,
ml: ml,
appConfig: appConfig,
}
}
func (lms *ListModelsService) ListModels(filter string, excludeConfigured bool) ([]schema.OpenAIModel, error) {
models, err := lms.ml.ListModels()
if err != nil {
return nil, err
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []schema.OpenAIModel{}
var filterFn func(name string) bool
// If filter is not specified, do not filter the list by model name
if filter == "" {
filterFn = func(_ string) bool { return true }
} else {
// If filter _IS_ specified, we compile it to a regex which is used to create the filterFn
rxp, err := regexp.Compile(filter)
if err != nil {
return nil, err
}
filterFn = func(name string) bool {
return rxp.MatchString(name)
}
}
// Start with the known configurations
for _, c := range lms.bcl.GetAllBackendConfigs() {
if excludeConfigured {
mm[c.Model] = nil
}
if filterFn(c.Name) {
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
}
}
// Then iterate through the loose files:
for _, m := range models {
// And only adds them if they shouldn't be skipped.
if _, exists := mm[m]; !exists && filterFn(m) {
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
}
return dataModels, nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/go-skynet/LocalAI/core"
"github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/services" "github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/internal"
@ -133,3 +134,33 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
log.Info().Msg("core/startup process completed!") log.Info().Msg("core/startup process completed!")
return cl, ml, options, nil return cl, ml, options, nil
} }
// In Lieu of a proper DI framework, this function wires up the Application manually.
// This is in core/startup rather than core/state.go to keep package references clean!
func createApplication(appConfig *config.ApplicationConfig) *core.Application {
app := &core.Application{
ApplicationConfig: appConfig,
BackendConfigLoader: config.NewBackendConfigLoader(),
ModelLoader: model.NewModelLoader(appConfig.ModelPath),
}
var err error
// app.EmbeddingsBackendService = backend.NewEmbeddingsBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.ImageGenerationBackendService = backend.NewImageGenerationBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.LLMBackendService = backend.NewLLMBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.TranscriptionBackendService = backend.NewTranscriptionBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.TextToSpeechBackendService = backend.NewTextToSpeechBackendService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.BackendMonitorService = services.NewBackendMonitorService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
app.GalleryService = services.NewGalleryService(app.ApplicationConfig.ModelPath)
app.ListModelsService = services.NewListModelsService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig)
// app.OpenAIService = services.NewOpenAIService(app.ModelLoader, app.BackendConfigLoader, app.ApplicationConfig, app.LLMBackendService)
app.LocalAIMetricsService, err = services.NewLocalAIMetricsService()
if err != nil {
log.Error().Err(err).Msg("encountered an error initializing metrics service, startup will continue but metrics will not be tracked.")
}
return app
}

View File

@ -41,7 +41,7 @@ type Backend interface {
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error)

View File

@ -53,8 +53,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented") return fmt.Errorf("unimplemented")
} }
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return schema.Result{}, fmt.Errorf("unimplemented") return schema.TranscriptionResult{}, fmt.Errorf("unimplemented")
} }
func (llm *Base) TTS(*pb.TTSRequest) error { func (llm *Base) TTS(*pb.TTSRequest) error {

View File

@ -210,7 +210,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
return client.TTS(ctx, in, opts...) return client.TTS(ctx, in, opts...)
} }
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
if !c.parallel { if !c.parallel {
c.opMutex.Lock() c.opMutex.Lock()
defer c.opMutex.Unlock() defer c.opMutex.Unlock()
@ -231,7 +231,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
if err != nil { if err != nil {
return nil, err return nil, err
} }
tresult := &schema.Result{} tresult := &schema.TranscriptionResult{}
for _, s := range res.Segments { for _, s := range res.Segments {
tks := []int{} tks := []int{}
for _, t := range s.Tokens { for _, t := range s.Tokens {

View File

@ -53,12 +53,12 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
return e.s.TTS(ctx, in) return e.s.TTS(ctx, in)
} }
func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { func (e *embedBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
r, err := e.s.AudioTranscription(ctx, in) r, err := e.s.AudioTranscription(ctx, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tr := &schema.Result{} tr := &schema.TranscriptionResult{}
for _, s := range r.Segments { for _, s := range r.Segments {
var tks []int var tks []int
for _, t := range s.Tokens { for _, t := range s.Tokens {

View File

@ -15,7 +15,7 @@ type LLM interface {
Load(*pb.ModelOptions) error Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error) Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error)
TTS(*pb.TTSRequest) error TTS(*pb.TTSRequest) error
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
Status() (pb.StatusResponse, error) Status() (pb.StatusResponse, error)

50
pkg/utils/base64.go Normal file
View File

@ -0,0 +1,50 @@
package utils
import (
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"time"
)
var base64DownloadClient http.Client = http.Client{
Timeout: 30 * time.Second,
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
// This may look weird down in pkg/utils while it is currently only used in core/config
//
// but I believe it may be useful for MQTT as well in the near future, so I'm
// extracting it while I'm thinking of it.
func GetImageURLAsBase64(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := base64DownloadClient.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}

31
pkg/utils/base64_test.go Normal file
View File

@ -0,0 +1,31 @@
package utils_test
import (
. "github.com/go-skynet/LocalAI/pkg/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("utils/base64 tests", func() {
It("GetImageURLAsBase64 can strip data url prefixes", func() {
// This one doesn't actually _care_ that it's base64, so feed "bad" data in this test in order to catch a change in that behavior for informational purposes.
input := ""
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).To(Equal("FOO"))
})
It("GetImageURLAsBase64 returns an error for bogus data", func() {
input := "FOO"
b64, err := GetImageURLAsBase64(input)
Expect(b64).To(Equal(""))
Expect(err).ToNot(BeNil())
Expect(err).To(MatchError("not valid string"))
})
It("GetImageURLAsBase64 can actually download images and calculates something", func() {
// This test doesn't actually _check_ the results at this time, which is bad, but there wasn't a test at all before...
input := "https://upload.wikimedia.org/wikipedia/en/2/29/Wargames.jpg"
b64, err := GetImageURLAsBase64(input)
Expect(err).To(BeNil())
Expect(b64).ToNot(BeNil())
})
})