mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
[Refactor]: Core/API Split (#1506)
Refactors api folder to core, creates firm split between backend code and api frontend.
This commit is contained in:
parent
bcf02449b3
commit
ab7b4d5ee9
4
.gitignore
vendored
4
.gitignore
vendored
@ -19,8 +19,8 @@ LocalAI
|
||||
local-ai
|
||||
# prevent above rules from omitting the helm chart
|
||||
!charts/*
|
||||
# prevent above rules from omitting the api/localai folder
|
||||
!api/localai
|
||||
# prevent above rules from omitting the core/**/localai folder
|
||||
!core/**/localai
|
||||
|
||||
# Ignore models
|
||||
models/*
|
||||
|
@ -88,7 +88,7 @@ ENV NVIDIA_VISIBLE_DEVICES=all
|
||||
WORKDIR /build
|
||||
|
||||
COPY . .
|
||||
COPY .git .
|
||||
COPY .git/ .git/
|
||||
RUN make prepare
|
||||
|
||||
# stablediffusion does not tolerate a newer version of abseil, build it first
|
||||
|
302
api/api.go
302
api/api.go
@ -1,302 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/localai"
|
||||
"github.com/go-skynet/LocalAI/api/openai"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/metrics"
|
||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
|
||||
options := options.NewOptions(opts...)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
if options.Debug {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
|
||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||
|
||||
modelPath := options.Loader.ModelPath
|
||||
if len(options.ModelsURL) > 0 {
|
||||
for _, url := range options.ModelsURL {
|
||||
if utils.LooksLikeURL(url) {
|
||||
// md5 of model name
|
||||
md5Name := utils.MD5(url)
|
||||
|
||||
// check if file exists
|
||||
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
|
||||
err := utils.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Msgf("error loading model: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cl := config.NewConfigLoader()
|
||||
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
|
||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
|
||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := cl.Preload(options.Loader.ModelPath); err != nil {
|
||||
log.Error().Msgf("error downloading models: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
for _, v := range cl.ListConfigs() {
|
||||
cfg, _ := cl.GetConfig(v)
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
if options.AssetsDestination != "" {
|
||||
// Extract files from the embedded FS
|
||||
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
||||
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
||||
if err != nil {
|
||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
options.Loader.StopAllGRPC()
|
||||
}()
|
||||
|
||||
if options.WatchDog {
|
||||
wd := model.NewWatchDog(
|
||||
options.Loader,
|
||||
options.WatchDogBusyTimeout,
|
||||
options.WatchDogIdleTimeout,
|
||||
options.WatchDogBusy,
|
||||
options.WatchDogIdle)
|
||||
options.Loader.SetWatchDog(wd)
|
||||
go wd.Run()
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
wd.Shutdown()
|
||||
}()
|
||||
}
|
||||
|
||||
return options, cl, nil
|
||||
}
|
||||
|
||||
func App(opts ...options.AppOption) (*fiber.App, error) {
|
||||
|
||||
options, cl, err := Startup(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
||||
}
|
||||
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: options.DisableMessage,
|
||||
// Override default error handler
|
||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
code := fiber.StatusInternalServerError
|
||||
|
||||
// Retrieve the custom status code if it's a *fiber.Error
|
||||
var e *fiber.Error
|
||||
if errors.As(err, &e) {
|
||||
code = e.Code
|
||||
}
|
||||
|
||||
// Send custom error page
|
||||
return ctx.Status(code).JSON(
|
||||
schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
},
|
||||
)
|
||||
},
|
||||
})
|
||||
|
||||
if options.Debug {
|
||||
app.Use(logger.New(logger.Config{
|
||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||
}))
|
||||
}
|
||||
|
||||
// Default middleware config
|
||||
app.Use(recover.New())
|
||||
if options.Metrics != nil {
|
||||
app.Use(metrics.APIMiddleware(options.Metrics))
|
||||
}
|
||||
|
||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||
auth := func(c *fiber.Ctx) error {
|
||||
if len(options.ApiKeys) == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Check for api_keys.json file
|
||||
fileContent, err := os.ReadFile("api_keys.json")
|
||||
if err == nil {
|
||||
// Parse JSON content from the file
|
||||
var fileKeys []string
|
||||
err := json.Unmarshal(fileContent, &fileKeys)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
|
||||
}
|
||||
|
||||
// Add file keys to options.ApiKeys
|
||||
options.ApiKeys = append(options.ApiKeys, fileKeys...)
|
||||
}
|
||||
|
||||
if len(options.ApiKeys) == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
||||
}
|
||||
authHeaderParts := strings.Split(authHeader, " ")
|
||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
||||
}
|
||||
|
||||
apiKey := authHeaderParts[1]
|
||||
for _, key := range options.ApiKeys {
|
||||
if apiKey == key {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
||||
|
||||
}
|
||||
|
||||
if options.CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if options.CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
||||
}
|
||||
|
||||
app.Use(c)
|
||||
}
|
||||
|
||||
// LocalAI API endpoints
|
||||
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
|
||||
galleryService.Start(options.Context, cl)
|
||||
|
||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
|
||||
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
||||
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
||||
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
|
||||
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
|
||||
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
|
||||
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
|
||||
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
|
||||
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
|
||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
|
||||
|
||||
if options.ImageDir != "" {
|
||||
app.Static("/generated-images", options.ImageDir)
|
||||
}
|
||||
|
||||
if options.AudioDir != "" {
|
||||
app.Static("/generated-audio", options.AudioDir)
|
||||
}
|
||||
|
||||
ok := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(200)
|
||||
}
|
||||
|
||||
// Kubernetes health checks
|
||||
app.Get("/healthz", ok)
|
||||
app.Get("/readyz", ok)
|
||||
|
||||
// Experimental Backend Statistics Module
|
||||
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
|
||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
||||
|
||||
// models
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
|
||||
|
||||
app.Get("/metrics", metrics.MetricsHandler())
|
||||
|
||||
return app, nil
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithContext(o.Context),
|
||||
model.WithModel(c.Model),
|
||||
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
|
||||
CUDA: c.CUDA || c.Diffusers.CUDA,
|
||||
SchedulerType: c.Diffusers.SchedulerType,
|
||||
PipelineType: c.Diffusers.PipelineType,
|
||||
CFGScale: c.Diffusers.CFGScale,
|
||||
LoraAdapter: c.LoraAdapter,
|
||||
LoraScale: c.LoraScale,
|
||||
LoraBase: c.LoraBase,
|
||||
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||
CLIPModel: c.Diffusers.ClipModel,
|
||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||
ControlNet: c.Diffusers.ControlNet,
|
||||
}),
|
||||
})
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
o.Context,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
Mode: int32(mode),
|
||||
Step: int32(step),
|
||||
Seed: int32(seed),
|
||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||
PositivePrompt: positive_prompt,
|
||||
NegativePrompt: negative_prompt,
|
||||
Dst: dst,
|
||||
Src: src,
|
||||
EnableParameters: c.Diffusers.EnableParameters,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
@ -1,167 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
type LLMResponse struct {
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
Prompt int
|
||||
Completion int
|
||||
}
|
||||
|
||||
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel *grpc.Client
|
||||
var err error
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
})
|
||||
|
||||
if c.Backend != "" {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
}
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
inferenceModel, err = loader.BackendLoader(opts...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
fn := func() (LLMResponse, error) {
|
||||
opts := gRPCPredictOpts(c, loader.ModelPath)
|
||||
opts.Prompt = s
|
||||
opts.Images = images
|
||||
|
||||
tokenUsage := TokenUsage{}
|
||||
|
||||
// check the per-model feature flag for usage, since tokenCallback may have a cost.
|
||||
// Defaults to off as for now it is still experimental
|
||||
if c.FeatureFlag.Enabled("usage") {
|
||||
userTokenCallback := tokenCallback
|
||||
if userTokenCallback == nil {
|
||||
userTokenCallback = func(token string, usage TokenUsage) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
|
||||
if pErr == nil && promptInfo.Length > 0 {
|
||||
tokenUsage.Prompt = int(promptInfo.Length)
|
||||
}
|
||||
|
||||
tokenCallback = func(token string, usage TokenUsage) bool {
|
||||
tokenUsage.Completion++
|
||||
return userTokenCallback(token, tokenUsage)
|
||||
}
|
||||
}
|
||||
|
||||
if tokenCallback != nil {
|
||||
ss := ""
|
||||
|
||||
var partialRune []byte
|
||||
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
|
||||
partialRune = append(partialRune, chars...)
|
||||
|
||||
for len(partialRune) > 0 {
|
||||
r, size := utf8.DecodeRune(partialRune)
|
||||
if r == utf8.RuneError {
|
||||
// incomplete rune, wait for more bytes
|
||||
break
|
||||
}
|
||||
|
||||
tokenCallback(string(r), tokenUsage)
|
||||
ss += string(r)
|
||||
|
||||
partialRune = partialRune[size:]
|
||||
}
|
||||
})
|
||||
return LLMResponse{
|
||||
Response: ss,
|
||||
Usage: tokenUsage,
|
||||
}, err
|
||||
} else {
|
||||
// TODO: Is the chicken bit the only way to get here? is that acceptable?
|
||||
reply, err := inferenceModel.Predict(ctx, opts)
|
||||
if err != nil {
|
||||
return LLMResponse{}, err
|
||||
}
|
||||
return LLMResponse{
|
||||
Response: string(reply.Message),
|
||||
Usage: tokenUsage,
|
||||
}, err
|
||||
}
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||
var mu sync.Mutex = sync.Mutex{}
|
||||
|
||||
func Finetune(config config.Config, input, prediction string) string {
|
||||
if config.Echo {
|
||||
prediction = input + prediction
|
||||
}
|
||||
|
||||
for _, c := range config.Cutstrings {
|
||||
mu.Lock()
|
||||
reg, ok := cutstrings[c]
|
||||
if !ok {
|
||||
cutstrings[c] = regexp.MustCompile(c)
|
||||
reg = cutstrings[c]
|
||||
}
|
||||
mu.Unlock()
|
||||
prediction = reg.ReplaceAllString(prediction, "")
|
||||
}
|
||||
|
||||
for _, c := range config.TrimSpace {
|
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
||||
}
|
||||
|
||||
for _, c := range config.TrimSuffix {
|
||||
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
|
||||
}
|
||||
return prediction
|
||||
}
|
@ -1,39 +0,0 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) {
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
})
|
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return nil, fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
Threads: uint32(c.Threads),
|
||||
})
|
||||
}
|
@ -1,162 +0,0 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
type BackendMonitorRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
}
|
||||
|
||||
type BackendMonitorResponse struct {
|
||||
MemoryInfo *gopsutil.MemoryInfoStat
|
||||
MemoryPercent float32
|
||||
CPUPercent float64
|
||||
}
|
||||
|
||||
type BackendMonitor struct {
|
||||
configLoader *config.ConfigLoader
|
||||
options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||
}
|
||||
|
||||
func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor {
|
||||
return BackendMonitor{
|
||||
configLoader: configLoader,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) {
|
||||
config, exists := bm.configLoader.GetConfig(model)
|
||||
var backend string
|
||||
if exists {
|
||||
backend = config.Model
|
||||
} else {
|
||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||
backend = model
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(backend, ".bin") {
|
||||
backend = fmt.Sprintf("%s.bin", backend)
|
||||
}
|
||||
|
||||
pid, err := bm.options.Loader.GetGRPCPID(backend)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
|
||||
backendProcess, err := gopsutil.NewProcess(int32(pid))
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memInfo, err := backendProcess.MemoryInfo()
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memPercent, err := backendProcess.MemoryPercent()
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuPercent, err := backendProcess.CPUPercent()
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &BackendMonitorResponse{
|
||||
MemoryInfo: memInfo,
|
||||
MemoryPercent: memPercent,
|
||||
CPUPercent: cpuPercent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) {
|
||||
input := new(BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
config, exists := bm.configLoader.GetConfig(input.Model)
|
||||
var backendId string
|
||||
if exists {
|
||||
backendId = config.Model
|
||||
} else {
|
||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||
backendId = input.Model
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(backendId, ".bin") {
|
||||
backendId = fmt.Sprintf("%s.bin", backendId)
|
||||
}
|
||||
|
||||
return backendId, nil
|
||||
}
|
||||
|
||||
func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
backendId, err := bm.getModelLoaderIDFromCtx(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model := bm.options.Loader.CheckIsLoaded(backendId)
|
||||
if model == "" {
|
||||
return fmt.Errorf("backend %s is not currently loaded", backendId)
|
||||
}
|
||||
|
||||
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
|
||||
if rpcErr != nil {
|
||||
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
||||
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
||||
if slbErr != nil {
|
||||
return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
||||
}
|
||||
return c.JSON(proto.StatusResponse{
|
||||
State: proto.StatusResponse_ERROR,
|
||||
Memory: &proto.MemoryUsageData{
|
||||
Total: val.MemoryInfo.VMS,
|
||||
Breakdown: map[string]uint64{
|
||||
"gopsutil-RSS": val.MemoryInfo.RSS,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backendId, err := bm.getModelLoaderIDFromCtx(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bm.options.Loader.ShutdownModel(backendId)
|
||||
}
|
||||
}
|
@ -1,326 +0,0 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
json "github.com/json-iterator/go"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type galleryOp struct {
|
||||
req gallery.GalleryModel
|
||||
id string
|
||||
galleries []gallery.Gallery
|
||||
galleryName string
|
||||
}
|
||||
|
||||
type galleryOpStatus struct {
|
||||
FileName string `json:"file_name"`
|
||||
Error error `json:"error"`
|
||||
Processed bool `json:"processed"`
|
||||
Message string `json:"message"`
|
||||
Progress float64 `json:"progress"`
|
||||
TotalFileSize string `json:"file_size"`
|
||||
DownloadedFileSize string `json:"downloaded_size"`
|
||||
}
|
||||
|
||||
type galleryApplier struct {
|
||||
modelPath string
|
||||
sync.Mutex
|
||||
C chan galleryOp
|
||||
statuses map[string]*galleryOpStatus
|
||||
}
|
||||
|
||||
func NewGalleryService(modelPath string) *galleryApplier {
|
||||
return &galleryApplier{
|
||||
modelPath: modelPath,
|
||||
C: make(chan galleryOp),
|
||||
statuses: make(map[string]*galleryOpStatus),
|
||||
}
|
||||
}
|
||||
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL(req.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
|
||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
||||
}
|
||||
|
||||
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
g.statuses[s] = op
|
||||
}
|
||||
|
||||
func (g *galleryApplier) getStatus(s string) *galleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses[s]
|
||||
}
|
||||
|
||||
func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses
|
||||
}
|
||||
|
||||
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.C:
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
|
||||
|
||||
// updates the status with an error
|
||||
updateError := func(e error) {
|
||||
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
||||
}
|
||||
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
var err error
|
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.galleryName != "" {
|
||||
if strings.Contains(op.galleryName, "@") {
|
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
|
||||
}
|
||||
} else {
|
||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
updateError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Reload models
|
||||
err = cm.LoadConfigs(g.modelPath)
|
||||
if err != nil {
|
||||
updateError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = cm.Preload(g.modelPath)
|
||||
if err != nil {
|
||||
updateError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
if strings.Contains(r.ID, "@") {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||
dat, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var requests []galleryModel
|
||||
|
||||
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error {
|
||||
var requests []galleryModel
|
||||
err := json.Unmarshal([]byte(s), &requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
/// Endpoint Service
|
||||
|
||||
type ModelGalleryService struct {
|
||||
galleries []gallery.Gallery
|
||||
modelPath string
|
||||
galleryApplier *galleryApplier
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
|
||||
return ModelGalleryService{
|
||||
galleries: galleries,
|
||||
modelPath: modelPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.galleryApplier.getAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.C <- galleryOp{
|
||||
req: input.GalleryModel,
|
||||
id: uuid.String(),
|
||||
galleryName: input.ID,
|
||||
galleries: mgs.galleries,
|
||||
}
|
||||
return c.JSON(struct {
|
||||
ID string `json:"uuid"`
|
||||
StatusURL string `json:"status"`
|
||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Models found from galleries: %+v", models)
|
||||
for _, m := range models {
|
||||
log.Debug().Msgf("Model found from galleries: %+v", m)
|
||||
}
|
||||
dat, err := json.Marshal(models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(gallery.Gallery)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
}) {
|
||||
return fmt.Errorf("%s already exists", input.Name)
|
||||
}
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
||||
mgs.galleries = append(mgs.galleries, *input)
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(gallery.Gallery)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
}) {
|
||||
return fmt.Errorf("%s is not currently registered", input.Name)
|
||||
}
|
||||
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
})
|
||||
return c.Send(nil)
|
||||
}
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type TTSRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Input string `json:"input" yaml:"input"`
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
}
|
||||
|
||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(TTSRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, o.Loader, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
@ -1,399 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
emptyMessage := ""
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
return func(c *fiber.Ctx) error {
|
||||
processFunctions := false
|
||||
funcs := grammar.Functions{}
|
||||
modelFile, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer"
|
||||
noActionDescription := "use this action to answer without performing any action"
|
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" {
|
||||
noActionName = config.FunctionsConfig.NoActionFunctionName
|
||||
}
|
||||
if config.FunctionsConfig.NoActionDescriptionName != "" {
|
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
|
||||
processFunctions = true
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
Name: noActionName,
|
||||
Description: noActionDescription,
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to reply the user with",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...)
|
||||
if !config.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(config.FunctionToCall())
|
||||
}
|
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
config.Grammar = jsStruct.Grammar("")
|
||||
} else if input.JSONFunctionGrammarObject != nil {
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
||||
}
|
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream && !processFunctions
|
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config)
|
||||
|
||||
var predInput string
|
||||
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
if toStream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
|
||||
if toStream {
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, o.Loader, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
usage := &schema.OpenAIUsage{}
|
||||
|
||||
for ev := range responses {
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
break
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: &emptyMessage},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
||||
if processFunctions {
|
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s = utils.EscapeNewLines(s)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name := ss["function"]
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args)
|
||||
|
||||
ss["arguments"] = string(d)
|
||||
ss["name"] = func_name
|
||||
|
||||
// if do nothing, reply with a message
|
||||
if func_name == noActionName {
|
||||
log.Debug().Msgf("nothing to do, computing a reply")
|
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(d), &arguments)
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
case string:
|
||||
if message != "" {
|
||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||
message = backend.Finetune(*config, predInput, message)
|
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||
|
||||
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU) another computation
|
||||
config.Grammar = ""
|
||||
images := []string{}
|
||||
for _, m := range input.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
|
||||
} else {
|
||||
// otherwise reply with the function call
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: "function_call",
|
||||
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "chat.completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
@ -1,199 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Text: s,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
}
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Completion != "" {
|
||||
templateFile = config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
if input.Stream {
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
}
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
go process(predInput, input, config, o.Loader, responses)
|
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
|
||||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
||||
for k, i := range config.PromptStrings {
|
||||
if templateFile != "" {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(
|
||||
input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
@ -1,94 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Edit != "" {
|
||||
templateFile = config.TemplateConfig.Edit
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
totalTokenUsage := backend.TokenUsage{}
|
||||
|
||||
for _, i := range config.InputStrings {
|
||||
if templateFile != "" {
|
||||
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "edit",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
@ -1,78 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
model, input, err := readInput(c, o, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
items := []schema.Item{}
|
||||
|
||||
for i, s := range config.InputToken {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
for i, s := range config.InputStrings {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
@ -1,239 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create the file
|
||||
out, err := os.CreateTemp("", "image")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Write the body to file
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
return out.Name(), err
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/*
|
||||
*
|
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "512x512"
|
||||
}'
|
||||
|
||||
*
|
||||
*/
|
||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readInput(c, o, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if m == "" {
|
||||
m = model.StableDiffusionBackend
|
||||
}
|
||||
log.Debug().Msgf("Loading model: %+v", m)
|
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
src := ""
|
||||
if input.File != "" {
|
||||
|
||||
fileData := []byte{}
|
||||
// check if input.File is an URL, if so download it and save it
|
||||
// to a temporary file
|
||||
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
||||
out, err := downloadFile(input.File)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed downloading file:%w", err)
|
||||
}
|
||||
defer os.RemoveAll(out)
|
||||
|
||||
fileData, err = os.ReadFile(out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading file:%w", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
// base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
fileData, err = base64.StdEncoding.DecodeString(input.File)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(fileData)
|
||||
if err != nil {
|
||||
outputFile.Close()
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
src = outputFile.Name()
|
||||
defer os.RemoveAll(src)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
switch config.Backend {
|
||||
case "stablediffusion":
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
case "tinydream":
|
||||
config.Backend = model.TinyDreamBackend
|
||||
case "":
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
}
|
||||
|
||||
sizeParts := strings.Split(input.Size, "x")
|
||||
if len(sizeParts) != 2 {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
width, err := strconv.Atoi(sizeParts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
height, err := strconv.Atoi(sizeParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Invalid value for 'size'")
|
||||
}
|
||||
|
||||
b64JSON := false
|
||||
if input.ResponseFormat.Type == "b64_json" {
|
||||
b64JSON = true
|
||||
}
|
||||
// src and clip_skip
|
||||
var result []schema.Item
|
||||
for _, i := range config.PromptStrings {
|
||||
n := input.N
|
||||
if input.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
prompts := strings.Split(i, "|")
|
||||
positive_prompt := prompts[0]
|
||||
negative_prompt := ""
|
||||
if len(prompts) > 1 {
|
||||
negative_prompt = prompts[1]
|
||||
}
|
||||
|
||||
mode := 0
|
||||
step := config.Step
|
||||
if step == 0 {
|
||||
step = 15
|
||||
}
|
||||
|
||||
if input.Mode != 0 {
|
||||
mode = input.Mode
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
step = input.Step
|
||||
}
|
||||
|
||||
tempDir := ""
|
||||
if !b64JSON {
|
||||
tempDir = o.ImageDir
|
||||
}
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
outputFile.Close()
|
||||
output := outputFile.Name() + ".png"
|
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseURL := c.BaseURL()
|
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item := &schema.Item{}
|
||||
|
||||
if b64JSON {
|
||||
defer os.RemoveAll(output)
|
||||
data, err := os.ReadFile(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = baseURL + "/generated-images/" + base
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Data: result,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
config *config.Config,
|
||||
o *options.Option,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
||||
n := req.N // number of completions to return
|
||||
result := []schema.Choice{}
|
||||
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range req.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
|
||||
tokenUsage := backend.TokenUsage{}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
|
||||
tokenUsage.Prompt += prediction.Usage.Prompt
|
||||
tokenUsage.Completion += prediction.Usage.Completion
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, tokenUsage, err
|
||||
}
|
@ -1,336 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
options "github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) {
|
||||
loader := o.Loader
|
||||
input := new(schema.OpenAIRequest)
|
||||
ctx, cancel := context.WithCancel(o.Context)
|
||||
input.Context = ctx
|
||||
input.Cancel = cancel
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
modelFile := input.Model
|
||||
|
||||
if c.Params("model") != "" {
|
||||
modelFile = c.Params("model")
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel {
|
||||
models, _ := loader.ListModels()
|
||||
if len(models) > 0 {
|
||||
modelFile = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", nil, fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelFile = bearer
|
||||
}
|
||||
return modelFile, input, nil
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func getBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
index := 0
|
||||
for i, m := range input.Messages {
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
for _, pp := range c {
|
||||
if pp.Type == "text" {
|
||||
input.Messages[i].StringContent = pp.Text
|
||||
} else if pp.Type == "image_url" {
|
||||
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
||||
base64, err := getBase64Image(pp.ImageURL.URL)
|
||||
if err == nil {
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
// set a placeholder for each image
|
||||
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
|
||||
index++
|
||||
} else {
|
||||
fmt.Print("Failed encoding image", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.LLMConfig.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
if input.TypicalP != 0 {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) {
|
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
|
||||
|
||||
var cfg *config.Config
|
||||
|
||||
defaults := func() {
|
||||
cfg = config.DefaultConfig(modelFile)
|
||||
cfg.ContextSize = ctx
|
||||
cfg.Threads = threads
|
||||
cfg.F16 = f16
|
||||
cfg.Debug = debug
|
||||
}
|
||||
|
||||
cfgExisting, exists := cm.GetConfig(modelFile)
|
||||
if !exists {
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfgExisting, exists = cm.GetConfig(modelFile)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
cfg = &cfgExisting
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(cfg, input)
|
||||
|
||||
// Don't allow 0 as setting
|
||||
if cfg.Threads == 0 {
|
||||
if threads != 0 {
|
||||
cfg.Threads = threads
|
||||
} else {
|
||||
cfg.Threads = 4
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug {
|
||||
cfg.Debug = true
|
||||
}
|
||||
|
||||
return cfg, input, nil
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
m, input, err := readInput(c, o, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename))
|
||||
dstFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr)
|
||||
}
|
||||
}
|
@ -8,7 +8,7 @@ import (
|
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
func sh(c string) (string, error) {
|
||||
@ -29,8 +29,8 @@ func audioToWav(src, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.Result, error) {
|
||||
res := schema.Result{}
|
||||
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.WhisperResult, error) {
|
||||
res := schema.WhisperResult{}
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
if err != nil {
|
||||
@ -90,7 +90,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
|
||||
tokens = append(tokens, t.Id)
|
||||
}
|
||||
|
||||
segment := schema.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
|
||||
segment := schema.WhisperSegment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
|
||||
res.Segments = append(res.Segments, segment)
|
||||
|
||||
res.Text += s.Text
|
||||
|
@ -4,9 +4,9 @@ package main
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
type Whisper struct {
|
||||
@ -21,6 +21,6 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.Result, error) {
|
||||
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.WhisperResult, error) {
|
||||
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
|
||||
}
|
||||
|
0
config/.keep
Normal file
0
config/.keep
Normal file
@ -2,14 +2,17 @@ package backend
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() ([]float32, error), error) {
|
||||
if !c.Embeddings {
|
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
|
||||
}
|
||||
@ -27,6 +30,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
model.WithExternalBackends(o.ExternalGRPCBackends, false),
|
||||
})
|
||||
|
||||
if c.Backend == "" {
|
||||
@ -90,3 +94,51 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
|
||||
return embeds, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func EmbeddingOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
|
||||
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
items := []schema.Item{}
|
||||
|
||||
for i, s := range config.InputToken {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding("", s, ml, *config, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
for i, s := range config.InputStrings {
|
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding(s, []int{}, ml, *config, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddings, err := embedFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
return &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items,
|
||||
Object: "list",
|
||||
}, nil
|
||||
}
|
210
core/backend/image.go
Normal file
210
core/backend/image.go
Normal file
@ -0,0 +1,210 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (func() error, error) {
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
model.WithBackendString(c.Backend),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithContext(o.Context),
|
||||
model.WithModel(c.Model),
|
||||
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
|
||||
CUDA: c.CUDA || c.Diffusers.CUDA,
|
||||
SchedulerType: c.Diffusers.SchedulerType,
|
||||
PipelineType: c.Diffusers.PipelineType,
|
||||
CFGScale: c.Diffusers.CFGScale,
|
||||
LoraAdapter: c.LoraAdapter,
|
||||
LoraScale: c.LoraScale,
|
||||
LoraBase: c.LoraBase,
|
||||
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||
CLIPModel: c.Diffusers.ClipModel,
|
||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||
ControlNet: c.Diffusers.ControlNet,
|
||||
}),
|
||||
model.WithExternalBackends(o.ExternalGRPCBackends, false),
|
||||
})
|
||||
|
||||
inferenceModel, err := loader.BackendLoader(
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn := func() error {
|
||||
_, err := inferenceModel.GenerateImage(
|
||||
o.Context,
|
||||
&proto.GenerateImageRequest{
|
||||
Height: int32(height),
|
||||
Width: int32(width),
|
||||
Mode: int32(mode),
|
||||
Step: int32(step),
|
||||
Seed: int32(seed),
|
||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||
PositivePrompt: positive_prompt,
|
||||
NegativePrompt: negative_prompt,
|
||||
Dst: dst,
|
||||
Src: src,
|
||||
EnableParameters: c.Diffusers.EnableParameters,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
func ImageGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
if modelName == "" {
|
||||
modelName = model.StableDiffusionBackend
|
||||
}
|
||||
log.Debug().Msgf("Loading model: %+v", modelName)
|
||||
|
||||
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading parameters from request: %w", err)
|
||||
}
|
||||
|
||||
src := ""
|
||||
if input.File != "" {
|
||||
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
||||
src, err = utils.CreateTempFileFromUrl(input.File, "", "image-src")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed downloading file:%w", err)
|
||||
}
|
||||
} else {
|
||||
src, err = utils.CreateTempFileFromBase64(input.File, "", "base64-image-src")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating temporary image source file: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
switch config.Backend {
|
||||
case "stablediffusion":
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
case "tinydream":
|
||||
config.Backend = model.TinyDreamBackend
|
||||
case "":
|
||||
config.Backend = model.StableDiffusionBackend
|
||||
}
|
||||
|
||||
sizeParts := strings.Split(input.Size, "x")
|
||||
if len(sizeParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid value for 'size'")
|
||||
}
|
||||
width, err := strconv.Atoi(sizeParts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value for 'size'")
|
||||
}
|
||||
height, err := strconv.Atoi(sizeParts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value for 'size'")
|
||||
}
|
||||
|
||||
b64JSON := false
|
||||
if input.ResponseFormat.Type == "b64_json" {
|
||||
b64JSON = true
|
||||
}
|
||||
// src and clip_skip
|
||||
var result []schema.Item
|
||||
for _, i := range config.PromptStrings {
|
||||
n := input.N
|
||||
if input.N == 0 {
|
||||
n = 1
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
prompts := strings.Split(i, "|")
|
||||
positive_prompt := prompts[0]
|
||||
negative_prompt := ""
|
||||
if len(prompts) > 1 {
|
||||
negative_prompt = prompts[1]
|
||||
}
|
||||
|
||||
mode := 0
|
||||
step := config.Step
|
||||
if step == 0 {
|
||||
step = 15
|
||||
}
|
||||
|
||||
if input.Mode != 0 {
|
||||
mode = input.Mode
|
||||
}
|
||||
|
||||
if input.Step != 0 {
|
||||
step = input.Step
|
||||
}
|
||||
|
||||
tempDir := ""
|
||||
if !b64JSON {
|
||||
tempDir = startupOptions.ImageDir
|
||||
}
|
||||
// Create a temporary file
|
||||
outputFile, err := os.CreateTemp(tempDir, "b64")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputFile.Close()
|
||||
output := outputFile.Name() + ".png"
|
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := fn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
item := &schema.Item{}
|
||||
|
||||
if b64JSON {
|
||||
defer os.RemoveAll(output)
|
||||
data, err := os.ReadFile(output)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
||||
} else {
|
||||
base := filepath.Base(output)
|
||||
item.URL = path.Join(startupOptions.ImageDir, base)
|
||||
}
|
||||
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
return &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Data: result,
|
||||
}, nil
|
||||
}
|
861
core/backend/llm.go
Normal file
861
core/backend/llm.go
Normal file
@ -0,0 +1,861 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
////////// TYPES //////////////
|
||||
|
||||
type LLMResponse struct {
|
||||
Response string // should this be []byte?
|
||||
Usage TokenUsage
|
||||
}
|
||||
|
||||
// TODO: Test removing this and using the variant in pkg/schema someday?
|
||||
type TokenUsage struct {
|
||||
Prompt int
|
||||
Completion int
|
||||
}
|
||||
|
||||
type TemplateConfigBindingFn func(*schema.Config) *string
|
||||
|
||||
// type LLMStreamProcessor func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse)
|
||||
|
||||
/////// CONSTS ///////////
|
||||
|
||||
const DEFAULT_NO_ACTION_NAME = "answer"
|
||||
const DEFAULT_NO_ACTION_DESCRIPTION = "use this action to answer without performing any action"
|
||||
|
||||
////// INFERENCE /////////
|
||||
|
||||
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
||||
modelFile := c.Model
|
||||
|
||||
grpcOpts := gRPCModelOpts(c)
|
||||
|
||||
var inferenceModel *grpc.Client
|
||||
var err error
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
model.WithExternalBackends(o.ExternalGRPCBackends, false),
|
||||
})
|
||||
|
||||
if c.Backend != "" {
|
||||
opts = append(opts, model.WithBackendString(c.Backend))
|
||||
}
|
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
||||
utils.ResetDownloadTimers()
|
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Backend == "" {
|
||||
inferenceModel, err = loader.GreedyLoader(opts...)
|
||||
} else {
|
||||
inferenceModel, err = loader.BackendLoader(opts...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
fn := func() (LLMResponse, error) {
|
||||
opts := gRPCPredictOpts(c, loader.ModelPath)
|
||||
opts.Prompt = s
|
||||
opts.Images = images
|
||||
|
||||
tokenUsage := TokenUsage{}
|
||||
|
||||
// check the per-model feature flag for usage, since tokenCallback may have a cost.
|
||||
// Defaults to off as for now it is still experimental
|
||||
if c.FeatureFlag.Enabled("usage") {
|
||||
userTokenCallback := tokenCallback
|
||||
if userTokenCallback == nil {
|
||||
userTokenCallback = func(token string, usage TokenUsage) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
|
||||
if pErr == nil && promptInfo.Length > 0 {
|
||||
tokenUsage.Prompt = int(promptInfo.Length)
|
||||
}
|
||||
|
||||
tokenCallback = func(token string, usage TokenUsage) bool {
|
||||
tokenUsage.Completion++
|
||||
return userTokenCallback(token, tokenUsage)
|
||||
}
|
||||
}
|
||||
|
||||
if tokenCallback != nil {
|
||||
ss := ""
|
||||
|
||||
var partialRune []byte
|
||||
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
|
||||
partialRune = append(partialRune, chars...)
|
||||
|
||||
for len(partialRune) > 0 {
|
||||
r, size := utf8.DecodeRune(partialRune)
|
||||
if r == utf8.RuneError {
|
||||
// incomplete rune, wait for more bytes
|
||||
break
|
||||
}
|
||||
|
||||
tokenCallback(string(r), tokenUsage)
|
||||
ss += string(r)
|
||||
|
||||
partialRune = partialRune[size:]
|
||||
}
|
||||
})
|
||||
return LLMResponse{
|
||||
Response: ss,
|
||||
Usage: tokenUsage,
|
||||
}, err
|
||||
} else {
|
||||
// TODO: Is the chicken bit the only way to get here? is that acceptable?
|
||||
reply, err := inferenceModel.Predict(ctx, opts)
|
||||
if err != nil {
|
||||
return LLMResponse{}, err
|
||||
}
|
||||
return LLMResponse{
|
||||
Response: string(reply.Message),
|
||||
Usage: tokenUsage,
|
||||
}, err
|
||||
}
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
}
|
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
|
||||
var mu sync.Mutex = sync.Mutex{}
|
||||
|
||||
func Finetune(config schema.Config, input, prediction string) string {
|
||||
if config.Echo {
|
||||
prediction = input + prediction
|
||||
}
|
||||
|
||||
for _, c := range config.Cutstrings {
|
||||
mu.Lock()
|
||||
reg, ok := cutstrings[c]
|
||||
if !ok {
|
||||
cutstrings[c] = regexp.MustCompile(c)
|
||||
reg = cutstrings[c]
|
||||
}
|
||||
mu.Unlock()
|
||||
prediction = reg.ReplaceAllString(prediction, "")
|
||||
}
|
||||
|
||||
for _, c := range config.TrimSpace {
|
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
||||
}
|
||||
|
||||
for _, c := range config.TrimSuffix {
|
||||
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
|
||||
}
|
||||
return prediction
|
||||
|
||||
}
|
||||
|
||||
////// CONFIG AND REQUEST HANDLING ///////////////
|
||||
|
||||
func ReadConfigFromFileAndCombineWithOpenAIRequest(modelFile string, input *schema.OpenAIRequest, cm *services.ConfigLoader, startupOptions *schema.StartupOptions) (*schema.Config, *schema.OpenAIRequest, error) {
|
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(startupOptions.ModelPath, modelFile+".yaml")
|
||||
|
||||
var cfg *schema.Config
|
||||
|
||||
defaults := func() {
|
||||
cfg = schema.DefaultConfig(modelFile)
|
||||
cfg.ContextSize = startupOptions.ContextSize
|
||||
cfg.Threads = startupOptions.Threads
|
||||
cfg.F16 = startupOptions.F16
|
||||
cfg.Debug = startupOptions.Debug
|
||||
}
|
||||
|
||||
cfgExisting, exists := cm.GetConfig(modelFile)
|
||||
if !exists {
|
||||
if _, err := os.Stat(modelConfig); err == nil {
|
||||
if err := cm.LoadConfig(modelConfig); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
|
||||
}
|
||||
cfgExisting, exists = cm.GetConfig(modelFile)
|
||||
if exists {
|
||||
cfg = &cfgExisting
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
defaults()
|
||||
}
|
||||
} else {
|
||||
cfg = &cfgExisting
|
||||
}
|
||||
|
||||
// Set the parameters for the language model prediction
|
||||
schema.UpdateConfigFromOpenAIRequest(cfg, input)
|
||||
|
||||
// Don't allow 0 as setting
|
||||
if cfg.Threads == 0 {
|
||||
if startupOptions.Threads != 0 {
|
||||
cfg.Threads = startupOptions.Threads
|
||||
} else {
|
||||
cfg.Threads = 4
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if startupOptions.Debug {
|
||||
cfg.Debug = true
|
||||
}
|
||||
|
||||
return cfg, input, nil
|
||||
}
|
||||
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
config *schema.Config,
|
||||
o *schema.StartupOptions,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
tokenCallback func(string, TokenUsage) bool) ([]schema.Choice, TokenUsage, error) {
|
||||
n := req.N // number of completions to return
|
||||
result := []schema.Choice{}
|
||||
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range req.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, TokenUsage{}, err
|
||||
}
|
||||
|
||||
tokenUsage := TokenUsage{}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, TokenUsage{}, err
|
||||
}
|
||||
|
||||
tokenUsage.Prompt += prediction.Usage.Prompt
|
||||
tokenUsage.Completion += prediction.Usage.Completion
|
||||
|
||||
finetunedResponse := Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, tokenUsage, err
|
||||
}
|
||||
|
||||
// TODO: No functions???? Commonize with prepareChatGenerationOpenAIRequest below?
|
||||
func prepareGenerationOpenAIRequest(bindingFn TemplateConfigBindingFn, modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, error) {
|
||||
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
configTemplate := bindingFn(config)
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if (*configTemplate == "") && (ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model))) {
|
||||
*configTemplate = config.Model
|
||||
}
|
||||
if *configTemplate == "" {
|
||||
return nil, fmt.Errorf(("failed to find templateConfig"))
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
////////// SPECIFIC REQUESTS //////////////
|
||||
// TODO: For round one of the refactor, give each of the three primary text endpoints their own function?
|
||||
// SEMITODO: During a merge, edit/completion were semi-combined - but remain nominally split
|
||||
// Can cleanup into a common form later if possible easier if they are all here for now
|
||||
// If they remain different, extract each of these named segments to a seperate file
|
||||
|
||||
func prepareChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.Config, string, bool, error) {
|
||||
|
||||
// IMPORTANT DEFS
|
||||
funcs := grammar.Functions{}
|
||||
|
||||
// The Basic Begining
|
||||
|
||||
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
log.Debug().Msgf("Configuration read: %+v", config)
|
||||
|
||||
// Special Input/Config Handling
|
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model - but if they are missing, use defaults.
|
||||
if config.FunctionsConfig.NoActionFunctionName == "" {
|
||||
config.FunctionsConfig.NoActionFunctionName = DEFAULT_NO_ACTION_NAME
|
||||
}
|
||||
if config.FunctionsConfig.NoActionDescriptionName == "" {
|
||||
config.FunctionsConfig.NoActionDescriptionName = DEFAULT_NO_ACTION_DESCRIPTION
|
||||
}
|
||||
|
||||
if input.ResponseFormat.Type == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
processFunctions := len(input.Functions) > 0 && config.ShouldUseFunctions()
|
||||
|
||||
if processFunctions {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
|
||||
noActionGrammar := grammar.Function{
|
||||
Name: config.FunctionsConfig.NoActionFunctionName,
|
||||
Description: config.FunctionsConfig.NoActionDescriptionName,
|
||||
Parameters: map[string]interface{}{
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to reply the user with",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...)
|
||||
if !config.FunctionsConfig.DisableNoAction {
|
||||
funcs = append(funcs, noActionGrammar)
|
||||
}
|
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" {
|
||||
funcs = funcs.Select(config.FunctionToCall())
|
||||
}
|
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure()
|
||||
config.Grammar = jsStruct.Grammar("")
|
||||
} else if input.JSONFunctionGrammarObject != nil {
|
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config)
|
||||
|
||||
var predInput string
|
||||
|
||||
suppressConfigSystemPrompt := false
|
||||
mess := []string{}
|
||||
for messageIndex, i := range input.Messages {
|
||||
var content string
|
||||
role := i.Role
|
||||
|
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" {
|
||||
roleFn := "assistant_function_call"
|
||||
r := config.Roles[roleFn]
|
||||
if r != "" {
|
||||
role = roleFn
|
||||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: i.StringContent,
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
if err != nil {
|
||||
log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err)
|
||||
} else {
|
||||
if templatedChatMessage == "" {
|
||||
log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf
|
||||
}
|
||||
log.Debug().Msgf("templated message for chat: %s", templatedChatMessage)
|
||||
content = templatedChatMessage
|
||||
}
|
||||
}
|
||||
// If this model doesn't have such a template, or if that template fails to return a value, template at the message level.
|
||||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + fmt.Sprint(r, " ", string(j))
|
||||
} else {
|
||||
content = fmt.Sprint(r, " ", string(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
if err == nil {
|
||||
if contentExists {
|
||||
content += "\n" + string(j)
|
||||
} else {
|
||||
content = string(j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Special Handling: System. We care if it was printed at all, not the r branch, so check seperately
|
||||
if contentExists && role == "system" {
|
||||
suppressConfigSystemPrompt = true
|
||||
}
|
||||
}
|
||||
|
||||
mess = append(mess, content)
|
||||
}
|
||||
|
||||
predInput = strings.Join(mess, "\n")
|
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput)
|
||||
|
||||
templateFile := ""
|
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
|
||||
templateFile = config.Model
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions {
|
||||
templateFile = config.TemplateConfig.Chat
|
||||
}
|
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions {
|
||||
templateFile = config.TemplateConfig.Functions
|
||||
}
|
||||
|
||||
if templateFile != "" {
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
SuppressSystemPrompt: suppressConfigSystemPrompt,
|
||||
Input: predInput,
|
||||
Functions: funcs,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
} else {
|
||||
log.Debug().Msgf("Template failed loading: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput)
|
||||
if processFunctions {
|
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar)
|
||||
}
|
||||
|
||||
return config, predInput, processFunctions, nil
|
||||
|
||||
}
|
||||
|
||||
func EditGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
binding := func(config *schema.Config) *string {
|
||||
return &config.TemplateConfig.Edit
|
||||
}
|
||||
|
||||
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
totalTokenUsage := TokenUsage{}
|
||||
|
||||
for _, i := range config.InputStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, config.TemplateConfig.Edit, model.PromptTemplateData{
|
||||
Input: i,
|
||||
Instruction: input.Instruction,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
return &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "edit",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
|
||||
|
||||
// DEFS
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Prepare
|
||||
config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
if processFunctions {
|
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{}
|
||||
// This prevent newlines to break JSON parsing for clients
|
||||
s = utils.EscapeNewLines(s)
|
||||
json.Unmarshal([]byte(s), &ss)
|
||||
log.Debug().Msgf("Function return: %s %+v", s, ss)
|
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name := ss["function"]
|
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args)
|
||||
|
||||
ss["arguments"] = string(d)
|
||||
ss["name"] = func_name
|
||||
|
||||
// if do nothing, reply with a message
|
||||
if func_name == config.FunctionsConfig.NoActionFunctionName {
|
||||
log.Debug().Msgf("nothing to do, computing a reply")
|
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(d), &arguments)
|
||||
m, exists := arguments["message"]
|
||||
if exists {
|
||||
switch message := m.(type) {
|
||||
case string:
|
||||
if message != "" {
|
||||
log.Debug().Msgf("Reply received from LLM: %s", message)
|
||||
message = Finetune(*config, predInput, message)
|
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
|
||||
|
||||
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
|
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU) another computation
|
||||
config.Grammar = ""
|
||||
images := []string{}
|
||||
for _, m := range input.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
predFunc, err := ModelInference(input.Context, predInput, images, ml, *config, startupOptions, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fineTunedResponse := Finetune(*config, predInput, prediction.Response)
|
||||
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
|
||||
} else {
|
||||
// otherwise reply with the function call
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: "function_call",
|
||||
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "chat.completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func CompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.OpenAIResponse, error) {
|
||||
// Prepare
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
binding := func(config *schema.Config) *string {
|
||||
return &config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result []schema.Choice
|
||||
|
||||
totalTokenUsage := TokenUsage{}
|
||||
|
||||
for k, i := range config.PromptStrings {
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Input: i,
|
||||
})
|
||||
if err == nil {
|
||||
i = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", i)
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(
|
||||
input, i, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
totalTokenUsage.Prompt += tokenUsage.Prompt
|
||||
totalTokenUsage.Completion += tokenUsage.Completion
|
||||
|
||||
result = append(result, r...)
|
||||
}
|
||||
|
||||
return &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result,
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: totalTokenUsage.Prompt,
|
||||
CompletionTokens: totalTokenUsage.Completion,
|
||||
TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func StreamingChatGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
|
||||
|
||||
// DEFS
|
||||
emptyMessage := ""
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Prepare
|
||||
config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if processFunctions {
|
||||
// TODO: unused variable means I did something wrong. investigate once stable
|
||||
log.Debug().Msgf("StreamingChatGenerationOpenAIRequest with processFunctions=true for %s?", config.Name)
|
||||
}
|
||||
|
||||
processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
responses <- initialMessage
|
||||
|
||||
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to create response channel")
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to start processor goroutine")
|
||||
|
||||
go processor(predInput, input, config, ml, responses)
|
||||
|
||||
log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: DONE! successfully returning to caller!")
|
||||
|
||||
return responses, nil
|
||||
|
||||
}
|
||||
|
||||
func StreamingCompletionGenerationOpenAIRequest(modelName string, input *schema.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (chan schema.OpenAIResponse, error) {
|
||||
// DEFS
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
binding := func(config *schema.Config) *string {
|
||||
return &config.TemplateConfig.Completion
|
||||
}
|
||||
|
||||
// Prepare
|
||||
|
||||
config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processor := func(s string, req *schema.OpenAIRequest, config *schema.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
|
||||
ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage TokenUsage) bool {
|
||||
resp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Text: s,
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
Usage: schema.OpenAIUsage{
|
||||
PromptTokens: usage.Prompt,
|
||||
CompletionTokens: usage.Completion,
|
||||
TotalTokens: usage.Prompt + usage.Completion,
|
||||
},
|
||||
}
|
||||
log.Debug().Msgf("Sending goroutine: %s", s)
|
||||
|
||||
responses <- resp
|
||||
return true
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
|
||||
if len(config.PromptStrings) > 1 {
|
||||
return nil, errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
|
||||
}
|
||||
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
//A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{
|
||||
Input: predInput,
|
||||
})
|
||||
if err == nil {
|
||||
predInput = templatedInput
|
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput)
|
||||
}
|
||||
|
||||
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to create response channel")
|
||||
|
||||
responses := make(chan schema.OpenAIResponse)
|
||||
|
||||
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to start processor goroutine")
|
||||
|
||||
go processor(predInput, input, config, ml, responses)
|
||||
|
||||
log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: DONE! successfully returning to caller!")
|
||||
|
||||
return responses, nil
|
||||
}
|
@ -5,13 +5,11 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option {
|
||||
func modelOpts(c schema.Config, o *schema.StartupOptions, opts []model.Option) []model.Option {
|
||||
if o.SingleBackend {
|
||||
opts = append(opts, model.WithSingleActiveBackend())
|
||||
}
|
||||
@ -35,7 +33,7 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.
|
||||
return opts
|
||||
}
|
||||
|
||||
func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
||||
func gRPCModelOpts(c schema.Config) *pb.ModelOptions {
|
||||
b := 512
|
||||
if c.Batch != 0 {
|
||||
b = c.Batch
|
||||
@ -82,7 +80,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
|
||||
func gRPCPredictOpts(c schema.Config, modelPath string) *pb.PredictOptions {
|
||||
promptCachePath := ""
|
||||
if c.PromptCachePath != "" {
|
||||
p := filepath.Join(modelPath, c.PromptCachePath)
|
52
core/backend/transcription.go
Normal file
52
core/backend/transcription.go
Normal file
@ -0,0 +1,52 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c schema.Config, o *schema.StartupOptions) (*schema.WhisperResult, error) {
|
||||
|
||||
opts := modelOpts(c, o, []model.Option{
|
||||
model.WithBackendString(model.WhisperBackend),
|
||||
model.WithModel(c.Model),
|
||||
model.WithContext(o.Context),
|
||||
model.WithThreads(uint32(c.Threads)),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithExternalBackends(o.ExternalGRPCBackends, false),
|
||||
})
|
||||
|
||||
whisperModel, err := loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if whisperModel == nil {
|
||||
return nil, fmt.Errorf("could not load whisper model")
|
||||
}
|
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||
Dst: audio,
|
||||
Language: language,
|
||||
Threads: uint32(c.Threads),
|
||||
})
|
||||
}
|
||||
|
||||
func TranscriptionOpenAIRequest(modelName string, input *schema.OpenAIRequest, audioFilePath string, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) (*schema.WhisperResult, error) {
|
||||
config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
tr, err := ModelTranscription(audioFilePath, input.Language, ml, *config, startupOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tr, nil
|
||||
}
|
@ -6,10 +6,9 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
api_config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
@ -29,18 +28,19 @@ func generateUniqueFileName(dir, baseName, ext string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
|
||||
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *schema.StartupOptions) (string, *proto.Result, error) {
|
||||
bb := backend
|
||||
if bb == "" {
|
||||
bb = model.PiperBackend
|
||||
}
|
||||
opts := modelOpts(api_config.Config{}, o, []model.Option{
|
||||
opts := modelOpts(schema.Config{}, o, []model.Option{
|
||||
model.WithBackendString(bb),
|
||||
model.WithModel(modelFile),
|
||||
model.WithContext(o.Context),
|
||||
model.WithAssetDir(o.AssetsDestination),
|
||||
model.WithExternalBackends(o.ExternalGRPCBackends, false),
|
||||
})
|
||||
piperModel, err := o.Loader.BackendLoader(opts...)
|
||||
piperModel, err := loader.BackendLoader(opts...)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@ -60,8 +60,8 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt
|
||||
modelPath := ""
|
||||
if modelFile != "" {
|
||||
if bb != model.TransformersMusicGen {
|
||||
modelPath = filepath.Join(o.Loader.ModelPath, modelFile)
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
|
||||
modelPath = filepath.Join(o.ModelPath, modelFile)
|
||||
if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
169
core/http/api.go
Normal file
169
core/http/api.go
Normal file
@ -0,0 +1,169 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
)
|
||||
|
||||
func App(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*fiber.App, error) {
|
||||
|
||||
// Return errors as JSON responses
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
DisableStartupMessage: options.DisableMessage,
|
||||
// Override default error handler
|
||||
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
code := fiber.StatusInternalServerError
|
||||
|
||||
// Retrieve the custom status code if it's a *fiber.Error
|
||||
var e *fiber.Error
|
||||
if errors.As(err, &e) {
|
||||
code = e.Code
|
||||
}
|
||||
|
||||
// Send custom error page
|
||||
return ctx.Status(code).JSON(
|
||||
schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
},
|
||||
)
|
||||
},
|
||||
})
|
||||
|
||||
if options.Debug {
|
||||
app.Use(logger.New(logger.Config{
|
||||
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
|
||||
}))
|
||||
}
|
||||
|
||||
// Default middleware config
|
||||
app.Use(recover.New())
|
||||
|
||||
if options.Metrics != nil {
|
||||
app.Use(localai.MetricsAPIMiddleware(options.Metrics))
|
||||
}
|
||||
|
||||
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
|
||||
auth := func(c *fiber.Ctx) error {
|
||||
if len(options.ApiKeys) == 0 {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
|
||||
}
|
||||
authHeaderParts := strings.Split(authHeader, " ")
|
||||
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
|
||||
}
|
||||
|
||||
apiKey := authHeaderParts[1]
|
||||
for _, key := range options.ApiKeys {
|
||||
if apiKey == key {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
|
||||
|
||||
}
|
||||
|
||||
if options.CORS {
|
||||
var c func(ctx *fiber.Ctx) error
|
||||
if options.CORSAllowOrigins == "" {
|
||||
c = cors.New()
|
||||
} else {
|
||||
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
|
||||
}
|
||||
|
||||
app.Use(c)
|
||||
}
|
||||
|
||||
// LocalAI API endpoints
|
||||
galleryService := services.NewGalleryApplier(options.ModelPath)
|
||||
galleryService.Start(options.Context, cl)
|
||||
|
||||
app.Get("/version", auth, func(c *fiber.Ctx) error {
|
||||
return c.JSON(struct {
|
||||
Version string `json:"version"`
|
||||
}{Version: internal.PrintableVersion()})
|
||||
})
|
||||
|
||||
modelGalleryService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService)
|
||||
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
|
||||
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
|
||||
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
|
||||
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
|
||||
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
|
||||
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
|
||||
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
|
||||
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
|
||||
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options))
|
||||
|
||||
// edit
|
||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options))
|
||||
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options))
|
||||
|
||||
// completion
|
||||
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options))
|
||||
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options))
|
||||
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options))
|
||||
|
||||
// embeddings
|
||||
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
|
||||
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
|
||||
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options))
|
||||
|
||||
// audio
|
||||
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options))
|
||||
app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options))
|
||||
|
||||
// images
|
||||
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options))
|
||||
|
||||
if options.ImageDir != "" {
|
||||
app.Static("/generated-images", options.ImageDir)
|
||||
}
|
||||
|
||||
if options.AudioDir != "" {
|
||||
app.Static("/generated-audio", options.AudioDir)
|
||||
}
|
||||
|
||||
ok := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(200)
|
||||
}
|
||||
|
||||
// Kubernetes health checks
|
||||
app.Get("/healthz", ok)
|
||||
app.Get("/readyz", ok)
|
||||
|
||||
app.Get("/metrics", localai.MetricsHandler())
|
||||
|
||||
backendMonitor := services.NewBackendMonitor(cl, ml, options)
|
||||
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
|
||||
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
|
||||
|
||||
// model listing
|
||||
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))
|
||||
|
||||
return app, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package api_test
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -13,11 +13,12 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/api"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/metrics"
|
||||
server "github.com/go-skynet/LocalAI/core/http"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/core/startup"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@ -118,16 +119,15 @@ var backendAssets embed.FS
|
||||
var _ = Describe("API test", func() {
|
||||
|
||||
var app *fiber.App
|
||||
var modelLoader *model.ModelLoader
|
||||
var client *openai.Client
|
||||
var client2 *openaigo.Client
|
||||
var c context.Context
|
||||
var cancel context.CancelFunc
|
||||
var tmpdir string
|
||||
|
||||
commonOpts := []options.AppOption{
|
||||
options.WithDebug(true),
|
||||
options.WithDisableMessage(true),
|
||||
commonOpts := []schema.AppOption{
|
||||
schema.WithDebug(true),
|
||||
schema.WithDisableMessage(true),
|
||||
}
|
||||
|
||||
Context("API with ephemeral models", func() {
|
||||
@ -136,7 +136,6 @@ var _ = Describe("API test", func() {
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(tmpdir)
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
g := []gallery.GalleryModel{
|
||||
@ -163,15 +162,20 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
metricsService, err := services.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
cl, ml, options, err := startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithMetrics(metricsService),
|
||||
options.WithContext(c),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
|
||||
schema.WithMetrics(metricsService),
|
||||
schema.WithContext(c),
|
||||
schema.WithGalleries(galleries),
|
||||
schema.WithModelPath(tmpdir),
|
||||
schema.WithBackendAssets(backendAssets),
|
||||
schema.WithBackendAssetsOutput(tmpdir))...)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = server.App(cl, ml, options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
@ -475,7 +479,6 @@ var _ = Describe("API test", func() {
|
||||
tmpdir, err = os.MkdirTemp("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelLoader = model.NewModelLoader(tmpdir)
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
galleries := []gallery.Gallery{
|
||||
@ -485,21 +488,22 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
metricsService, err := services.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
cl, ml, options, err := startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithMetrics(metricsService),
|
||||
options.WithAudioDir(tmpdir),
|
||||
options.WithImageDir(tmpdir),
|
||||
options.WithGalleries(galleries),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(tmpdir))...,
|
||||
schema.WithContext(c),
|
||||
schema.WithMetrics(metricsService),
|
||||
schema.WithAudioDir(tmpdir),
|
||||
schema.WithImageDir(tmpdir),
|
||||
schema.WithGalleries(galleries),
|
||||
schema.WithModelPath(tmpdir),
|
||||
schema.WithBackendAssets(backendAssets),
|
||||
schema.WithBackendAssetsOutput(tmpdir))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = server.App(cl, ml, options)
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
@ -590,20 +594,21 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Context("API query", func() {
|
||||
BeforeEach(func() {
|
||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
metricsService, err := services.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
cl, ml, options, err := startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
options.WithContext(c),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithMetrics(metricsService),
|
||||
schema.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
schema.WithContext(c),
|
||||
schema.WithModelPath(os.Getenv("MODELS_PATH")),
|
||||
schema.WithMetrics(metricsService),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = server.App(cl, ml, options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
@ -802,20 +807,21 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Context("Config file", func() {
|
||||
BeforeEach(func() {
|
||||
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
metricsService, err := metrics.SetupMetrics()
|
||||
metricsService, err := services.SetupMetrics()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = App(
|
||||
cl, ml, options, err := startup.Startup(
|
||||
append(commonOpts,
|
||||
options.WithContext(c),
|
||||
options.WithMetrics(metricsService),
|
||||
options.WithModelLoader(modelLoader),
|
||||
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
schema.WithContext(c),
|
||||
schema.WithMetrics(metricsService),
|
||||
schema.WithModelPath(os.Getenv("MODELS_PATH")),
|
||||
schema.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = server.App(cl, ml, options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
@ -1,4 +1,4 @@
|
||||
package api_test
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"testing"
|
34
core/http/endpoints/localai/backend_monitor.go
Normal file
34
core/http/endpoints/localai/backend_monitor.go
Normal file
@ -0,0 +1,34 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func BackendMonitorEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := bm.CheckAndSample(input.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
||||
|
||||
func BackendShutdownEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.BackendMonitorRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
return bm.ShutdownModel(input.Model)
|
||||
}
|
||||
}
|
148
core/http/endpoints/localai/gallery.go
Normal file
148
core/http/endpoints/localai/gallery.go
Normal file
@ -0,0 +1,148 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
/// Endpoint Service
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []gallery.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryApplier
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryApplier) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
modelPath: modelPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.C <- gallery.GalleryOp{
|
||||
Req: input.GalleryModel,
|
||||
Id: uuid.String(),
|
||||
GalleryName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
}
|
||||
return c.JSON(struct {
|
||||
ID string `json:"uuid"`
|
||||
StatusURL string `json:"status"`
|
||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Models found from galleries: %+v", models)
|
||||
for _, m := range models {
|
||||
log.Debug().Msgf("Model found from galleries: %+v", m)
|
||||
}
|
||||
dat, err := json.Marshal(models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(gallery.Gallery)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
}) {
|
||||
return fmt.Errorf("%s already exists", input.Name)
|
||||
}
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("Adding %+v to gallery list", *input)
|
||||
mgs.galleries = append(mgs.galleries, *input)
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(gallery.Gallery)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
}) {
|
||||
return fmt.Errorf("%s is not currently registered", input.Name)
|
||||
}
|
||||
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
|
||||
return gallery.Name == input.Name
|
||||
})
|
||||
return c.Send(nil)
|
||||
}
|
||||
}
|
42
core/http/endpoints/localai/metrics.go
Normal file
42
core/http/endpoints/localai/metrics.go
Normal file
@ -0,0 +1,42 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
func MetricsHandler() fiber.Handler {
|
||||
return adaptor.HTTPHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c *fiber.Ctx) bool
|
||||
metrics *schema.LocalAIMetrics
|
||||
}
|
||||
|
||||
func MetricsAPIMiddleware(metrics *schema.LocalAIMetrics) fiber.Handler {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metrics: metrics,
|
||||
Filter: func(c *fiber.Ctx) bool {
|
||||
return c.Path() == "/metrics"
|
||||
},
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return c.Next()
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Method()
|
||||
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metrics.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
25
core/http/endpoints/localai/tts.go
Normal file
25
core/http/endpoints/localai/tts.go
Normal file
@ -0,0 +1,25 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func TTSEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(schema.TTSRequest)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, ml, so)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Download(filePath)
|
||||
}
|
||||
}
|
97
core/http/endpoints/openai/chat.go
Normal file
97
core/http/endpoints/openai/chat.go
Normal file
@ -0,0 +1,97 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func ChatEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
|
||||
emptyMessage := ""
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName, input, err := readInput(c, startupOptions, ml, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
// The scary comment I feel like I forgot about along the way:
|
||||
//
|
||||
// functions are not supported in stream mode (yet?)
|
||||
//
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
|
||||
responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed establishing streaming chat request :%w", err)
|
||||
}
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
usage := &schema.OpenAIUsage{}
|
||||
id := ""
|
||||
created := 0
|
||||
for ev := range responses {
|
||||
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
|
||||
id = ev.ID
|
||||
created = ev.Created // Similarly, grab the ID and created from any / the last response so we can use it for the stop
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
if err != nil {
|
||||
log.Debug().Msgf("Sending chunk failed: %v", err)
|
||||
input.Cancel()
|
||||
break
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: &emptyMessage},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
//////////////////////////////////////////
|
||||
|
||||
resp, err := backend.ChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating chat request: +%w", err)
|
||||
}
|
||||
respData, _ := json.Marshal(resp) // TODO this is only used for the debug log and costs performance. monitor this?
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
91
core/http/endpoints/openai/completion.go
Normal file
91
core/http/endpoints/openai/completion.go
Normal file
@ -0,0 +1,91 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
id := uuid.New().String()
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName, input, err := readInput(c, so, ml, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("`input`: %+v", input)
|
||||
|
||||
if input.Stream {
|
||||
log.Debug().Msgf("Stream request received")
|
||||
c.Context().SetContentType("text/event-stream")
|
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache")
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
|
||||
responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed establishing streaming completion request :%w", err)
|
||||
}
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
|
||||
|
||||
for ev := range responses {
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.Encode(ev)
|
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String())
|
||||
fmt.Fprintf(w, "data: %v\n", buf.String())
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Object: "text_completion",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////
|
||||
|
||||
resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating completion request: +%w", err)
|
||||
}
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
34
core/http/endpoints/openai/edit.go
Normal file
34
core/http/endpoints/openai/edit.go
Normal file
@ -0,0 +1,34 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func EditEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readInput(c, so, ml, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
resp, err := backend.EditGenerationOpenAIRequest(modelFile, input, cl, ml, so)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
35
core/http/endpoints/openai/embeddings.go
Normal file
35
core/http/endpoints/openai/embeddings.go
Normal file
@ -0,0 +1,35 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelFile, input, err := readInput(c, so, ml, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
resp, err := backend.EmbeddingOpenAIRequest(modelFile, input, cl, ml, so)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
48
core/http/endpoints/openai/image.go
Normal file
48
core/http/endpoints/openai/image.go
Normal file
@ -0,0 +1,48 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/*
|
||||
*
|
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "512x512"
|
||||
}'
|
||||
|
||||
*
|
||||
*/
|
||||
func ImageEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName, input, err := readInput(c, so, ml, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
resp, err := backend.ImageGenerationOpenAIRequest(modelName, input, cl, ml, so)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating image request: +%w", err)
|
||||
}
|
||||
|
||||
jsonResult, _ := json.Marshal(resp)
|
||||
log.Debug().Msgf("Response: %s", jsonResult)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
}
|
||||
}
|
@ -3,21 +3,21 @@ package openai
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
models, err := loader.ListModels()
|
||||
models, err := ml.ListModels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var mm map[string]interface{} = map[string]interface{}{}
|
||||
|
||||
dataModels := []schema.OpenAIModel{}
|
||||
openAIModels := []schema.OpenAIModel{}
|
||||
|
||||
var filterFn func(name string) bool
|
||||
filter := c.Query("filter")
|
||||
@ -40,13 +40,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
|
||||
excludeConfigured := c.QueryBool("excludeConfigured", true)
|
||||
|
||||
// Start with the known configurations
|
||||
for _, c := range cm.GetAllConfigs() {
|
||||
for _, c := range cl.GetAllConfigs() {
|
||||
if excludeConfigured {
|
||||
mm[c.Model] = nil
|
||||
}
|
||||
|
||||
if filterFn(c.Name) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
|
||||
openAIModels = append(openAIModels, schema.OpenAIModel{ID: c.Name, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,7 +54,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
|
||||
for _, m := range models {
|
||||
// And only adds them if they shouldn't be skipped.
|
||||
if _, exists := mm[m]; !exists && filterFn(m) {
|
||||
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
openAIModels = append(openAIModels, schema.OpenAIModel{ID: m, Object: "model"})
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,7 +63,7 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func
|
||||
Data []schema.OpenAIModel `json:"data"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: dataModels,
|
||||
Data: openAIModels,
|
||||
})
|
||||
}
|
||||
}
|
57
core/http/endpoints/openai/request.go
Normal file
57
core/http/endpoints/openai/request.go
Normal file
@ -0,0 +1,57 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func readInput(c *fiber.Ctx, o *schema.StartupOptions, ml *model.ModelLoader, randomModel bool) (string, *schema.OpenAIRequest, error) {
|
||||
input := new(schema.OpenAIRequest)
|
||||
ctx, cancel := context.WithCancel(o.Context)
|
||||
input.Context = ctx
|
||||
input.Cancel = cancel
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
||||
}
|
||||
|
||||
modelFile := input.Model
|
||||
|
||||
if c.Params("model") != "" {
|
||||
modelFile = c.Params("model")
|
||||
}
|
||||
|
||||
received, _ := json.Marshal(input)
|
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received))
|
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
|
||||
bearerExists := bearer != "" && ml.ExistsInModelPath(bearer)
|
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel {
|
||||
models, _ := ml.ListModels()
|
||||
if len(models) > 0 {
|
||||
modelFile = models[0]
|
||||
log.Debug().Msgf("No model specified, using: %s", modelFile)
|
||||
} else {
|
||||
log.Debug().Msgf("No model specified, returning error")
|
||||
return "", nil, fmt.Errorf("no model specified")
|
||||
}
|
||||
}
|
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists {
|
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer)
|
||||
modelFile = bearer
|
||||
}
|
||||
return modelFile, input, nil
|
||||
}
|
49
core/http/endpoints/openai/transcription.go
Normal file
49
core/http/endpoints/openai/transcription.go
Normal file
@ -0,0 +1,49 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName, input, err := readInput(c, so, ml, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst, err := utils.CreateTempFileFromMultipartFile(file, "", "transcription") // 3rd param formerly whisper
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
defer os.RemoveAll(path.Dir(dst))
|
||||
|
||||
tr, err := backend.TranscriptionOpenAIRequest(modelName, input, dst, cl, ml, so)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating transcription request: +%w", err)
|
||||
}
|
||||
log.Debug().Msgf("Trascribed: %+v", tr)
|
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr)
|
||||
}
|
||||
}
|
24
core/mqtt/manager.go
Normal file
24
core/mqtt/manager.go
Normal file
@ -0,0 +1,24 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
// PLACEHOLDER DURING PART 1 OF THE REFACTOR
|
||||
|
||||
type MQTTManager struct {
|
||||
configLoader *services.ConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
startupOptions *schema.StartupOptions
|
||||
}
|
||||
|
||||
func NewMQTTManager(cl *services.ConfigLoader, ml *model.ModelLoader, options *schema.StartupOptions) (*MQTTManager, error) {
|
||||
|
||||
return &MQTTManager{
|
||||
configLoader: cl,
|
||||
modelLoader: ml,
|
||||
startupOptions: options,
|
||||
}, nil
|
||||
}
|
138
core/services/backend_monitor.go
Normal file
138
core/services/backend_monitor.go
Normal file
@ -0,0 +1,138 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
type BackendMonitor struct {
|
||||
configLoader *ConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
options *schema.StartupOptions // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
|
||||
}
|
||||
|
||||
func NewBackendMonitor(configLoader *ConfigLoader, modelLoader *model.ModelLoader, options *schema.StartupOptions) *BackendMonitor {
|
||||
return &BackendMonitor{
|
||||
configLoader: configLoader,
|
||||
modelLoader: modelLoader,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
|
||||
config, exists := bm.configLoader.GetConfig(model)
|
||||
var backend string
|
||||
if exists {
|
||||
backend = config.Model
|
||||
} else {
|
||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||
backend = model
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(backend, ".bin") {
|
||||
backend = fmt.Sprintf("%s.bin", backend)
|
||||
}
|
||||
|
||||
pid, err := bm.modelLoader.GetGRPCPID(backend)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s : failed to find pid %+v", model, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID.
|
||||
backendProcess, err := gopsutil.NewProcess(int32(pid))
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memInfo, err := backendProcess.MemoryInfo()
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memPercent, err := backendProcess.MemoryPercent()
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cpuPercent, err := backendProcess.CPUPercent()
|
||||
if err != nil {
|
||||
log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &schema.BackendMonitorResponse{
|
||||
MemoryInfo: memInfo,
|
||||
MemoryPercent: memPercent,
|
||||
CPUPercent: cpuPercent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
|
||||
config, exists := bm.configLoader.GetConfig(modelName)
|
||||
var backendId string
|
||||
if exists {
|
||||
backendId = config.Model
|
||||
} else {
|
||||
// Last ditch effort: use it raw, see if a backend happens to match.
|
||||
backendId = modelName
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(backendId, ".bin") {
|
||||
backendId = fmt.Sprintf("%s.bin", backendId)
|
||||
}
|
||||
|
||||
return backendId, nil
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
|
||||
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelAddr := bm.modelLoader.CheckIsLoaded(backendId)
|
||||
if modelAddr == "" {
|
||||
return nil, fmt.Errorf("backend %s is not currently loaded", backendId)
|
||||
}
|
||||
|
||||
status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO())
|
||||
if rpcErr != nil {
|
||||
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
|
||||
val, slbErr := bm.SampleLocalBackendProcess(backendId)
|
||||
if slbErr != nil {
|
||||
return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error())
|
||||
}
|
||||
return &proto.StatusResponse{
|
||||
State: proto.StatusResponse_ERROR,
|
||||
Memory: &proto.MemoryUsageData{
|
||||
Total: val.MemoryInfo.VMS,
|
||||
Breakdown: map[string]uint64{
|
||||
"gopsutil-RSS": val.MemoryInfo.RSS,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (bm BackendMonitor) ShutdownModel(modelName string) error {
|
||||
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bm.modelLoader.ShutdownModel(backendId)
|
||||
}
|
157
core/services/config.go
Normal file
157
core/services/config.go
Normal file
@ -0,0 +1,157 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ConfigLoader struct {
|
||||
configs map[string]schema.Config
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
configs: make(map[string]schema.Config),
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check this is correct post-merge
|
||||
func (cm *ConfigLoader) LoadConfig(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := schema.ReadSingleConfigFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read config file: %w", err)
|
||||
}
|
||||
|
||||
cm.configs[c.Name] = *c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) GetConfig(m string) (schema.Config, bool) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
v, exists := cm.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) GetAllConfigs() []schema.Config {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []schema.Config
|
||||
for _, v := range cm.configs {
|
||||
res = append(res, v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) ListConfigs() []string {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []string
|
||||
for k := range cm.configs {
|
||||
res = append(res, k)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfigs(path string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files := make([]fs.FileInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, info)
|
||||
}
|
||||
for _, file := range files {
|
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
|
||||
continue
|
||||
}
|
||||
c, err := schema.ReadSingleConfigFile(filepath.Join(path, file.Name()))
|
||||
if err == nil {
|
||||
cm.configs[c.Name] = *c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Preload prepare models if they are not local but url or huggingface repositories
|
||||
func (cm *ConfigLoader) Preload(modelPath string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
status := func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Preloading models from %s", modelPath)
|
||||
|
||||
for _, config := range cm.configs {
|
||||
|
||||
// Download files and verify their SHA
|
||||
for _, file := range config.DownloadFiles {
|
||||
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
||||
|
||||
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Create file path
|
||||
filePath := filepath.Join(modelPath, file.Filename)
|
||||
|
||||
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
modelURL := config.PredictionOptions.Model
|
||||
modelURL = utils.ConvertURL(modelURL)
|
||||
|
||||
if utils.LooksLikeURL(modelURL) {
|
||||
// md5 of model name
|
||||
md5Name := utils.MD5(modelURL)
|
||||
|
||||
// check if file exists
|
||||
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
|
||||
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cl *ConfigLoader) LoadConfigFile(file string) error {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
c, err := schema.ReadConfigFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
}
|
||||
|
||||
for _, cc := range c {
|
||||
cl.configs[cc.Name] = *cc
|
||||
}
|
||||
return nil
|
||||
}
|
160
core/services/gallery.go
Normal file
160
core/services/gallery.go
Normal file
@ -0,0 +1,160 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type GalleryApplier struct {
|
||||
modelPath string
|
||||
sync.Mutex
|
||||
C chan gallery.GalleryOp
|
||||
statuses map[string]*gallery.GalleryOpStatus
|
||||
}
|
||||
|
||||
func NewGalleryApplier(modelPath string) *GalleryApplier {
|
||||
return &GalleryApplier{
|
||||
modelPath: modelPath,
|
||||
C: make(chan gallery.GalleryOp),
|
||||
statuses: make(map[string]*gallery.GalleryOpStatus),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
g.statuses[s] = op
|
||||
}
|
||||
|
||||
func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses[s]
|
||||
}
|
||||
|
||||
func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
return g.statuses
|
||||
}
|
||||
|
||||
func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case op := <-g.C:
|
||||
utils.ResetDownloadTimers()
|
||||
|
||||
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0})
|
||||
|
||||
// updates the status with an error
|
||||
updateError := func(e error) {
|
||||
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
|
||||
}
|
||||
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) {
|
||||
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage)
|
||||
}
|
||||
|
||||
var err error
|
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.GalleryName != "" {
|
||||
if strings.Contains(op.GalleryName, "@") {
|
||||
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
|
||||
}
|
||||
} else {
|
||||
err = PrepareModel(g.modelPath, op.Req, cm, progressCallback)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
updateError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Reload models
|
||||
err = cm.LoadConfigs(g.modelPath)
|
||||
if err != nil {
|
||||
updateError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type galleryModel struct {
|
||||
gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error {
|
||||
|
||||
config, err := gallery.GetInstallableModelFromURL(req.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...)
|
||||
|
||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus)
|
||||
}
|
||||
|
||||
func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error {
|
||||
var err error
|
||||
for _, r := range requests {
|
||||
utils.ResetDownloadTimers()
|
||||
if r.ID == "" {
|
||||
err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
if strings.Contains(r.ID, "@") {
|
||||
err = gallery.InstallModelFromGallery(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
} else {
|
||||
err = gallery.InstallModelFromGalleryByName(
|
||||
galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
|
||||
dat, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var requests []galleryModel
|
||||
|
||||
if err := yaml.Unmarshal(dat, &requests); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error {
|
||||
var requests []galleryModel
|
||||
err := json.Unmarshal([]byte(s), &requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return processRequests(modelPath, s, cm, galleries, requests)
|
||||
}
|
29
core/services/metrics.go
Normal file
29
core/services/metrics.go
Normal file
@ -0,0 +1,29 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||
api "go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric"
|
||||
)
|
||||
|
||||
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
|
||||
// If it does not return an error, make sure to call shutdown for proper cleanup.
|
||||
func SetupMetrics() (*schema.LocalAIMetrics, error) {
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
meter := provider.Meter("github.com/go-skynet/LocalAI")
|
||||
|
||||
apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &schema.LocalAIMetrics{
|
||||
Meter: meter,
|
||||
ApiTimeMetric: apiTimeMetric,
|
||||
}, nil
|
||||
}
|
100
core/startup/config_file_watcher.go
Normal file
100
core/startup/config_file_watcher.go
Normal file
@ -0,0 +1,100 @@
|
||||
package startup
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type WatchConfigDirectoryCloser func() error
|
||||
|
||||
func ReadApiKeysJson(configDir string, options *schema.StartupOptions) error {
|
||||
fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json"))
|
||||
if err == nil {
|
||||
// Parse JSON content from the file
|
||||
var fileKeys []string
|
||||
err := json.Unmarshal(fileContent, &fileKeys)
|
||||
if err == nil {
|
||||
options.ApiKeys = append(options.ApiKeys, fileKeys...)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ReadExternalBackendsJson(configDir string, options *schema.StartupOptions) error {
|
||||
fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Parse JSON content from the file
|
||||
var fileBackends map[string]string
|
||||
err = json.Unmarshal(fileContent, &fileBackends)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *schema.StartupOptions) error{
|
||||
"api_keys.json": ReadApiKeysJson,
|
||||
"external_backends.json": ReadExternalBackendsJson,
|
||||
}
|
||||
|
||||
func WatchConfigDirectory(configDir string, options *schema.StartupOptions) (WatchConfigDirectoryCloser, error) {
|
||||
if len(configDir) == 0 {
|
||||
return nil, fmt.Errorf("configDir blank")
|
||||
}
|
||||
configWatcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err)
|
||||
}
|
||||
ret := func() error {
|
||||
configWatcher.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start listening for events.
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-configWatcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) {
|
||||
for targetName, watchFn := range CONFIG_FILE_UPDATES {
|
||||
if event.Name == targetName {
|
||||
err := watchFn(configDir, options)
|
||||
log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
case _, ok := <-configWatcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Add a path.
|
||||
err = configWatcher.Add(configDir)
|
||||
if err != nil {
|
||||
return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
93
core/startup/startup.go
Normal file
93
core/startup/startup.go
Normal file
@ -0,0 +1,93 @@
|
||||
package startup
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/pkg/assets"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Startup(opts ...schema.AppOption) (*services.ConfigLoader, *model.ModelLoader, *schema.StartupOptions, error) {
|
||||
options := schema.NewStartupOptions(opts...)
|
||||
|
||||
ml := model.NewModelLoader(options.ModelPath)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
if options.Debug {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath)
|
||||
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
|
||||
|
||||
cl := services.NewConfigLoader()
|
||||
if err := cl.LoadConfigs(options.ModelPath); err != nil {
|
||||
log.Error().Msgf("error loading config files: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.ConfigFile != "" {
|
||||
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
|
||||
log.Error().Msgf("error loading config file: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := cl.Preload(options.ModelPath); err != nil {
|
||||
log.Error().Msgf("error downloading models: %s", err.Error())
|
||||
}
|
||||
|
||||
if options.PreloadJSONModels != "" {
|
||||
if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.PreloadModelsFromPath != "" {
|
||||
if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if options.Debug {
|
||||
for _, v := range cl.ListConfigs() {
|
||||
cfg, _ := cl.GetConfig(v)
|
||||
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
if options.AssetsDestination != "" {
|
||||
// Extract files from the embedded FS
|
||||
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
|
||||
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
|
||||
if err != nil {
|
||||
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// turn off any process that was started by GRPC if the context is canceled
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
ml.StopAllGRPC()
|
||||
}()
|
||||
|
||||
if options.WatchDog {
|
||||
wd := model.NewWatchDog(
|
||||
ml,
|
||||
options.WatchDogBusyTimeout,
|
||||
options.WatchDogIdleTimeout,
|
||||
options.WatchDogBusy,
|
||||
options.WatchDogIdle)
|
||||
ml.SetWatchDog(wd)
|
||||
go wd.Run()
|
||||
go func() {
|
||||
<-options.Context.Done()
|
||||
log.Debug().Msgf("Context canceled, shutting down")
|
||||
wd.Shutdown()
|
||||
}()
|
||||
}
|
||||
|
||||
return cl, ml, options, nil
|
||||
}
|
@ -17,6 +17,53 @@ This section will collect how-to, notes and development documentation
|
||||
|
||||
We use conventional commits and semantic versioning. Please follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) specification when writing commit messages.
|
||||
|
||||
## LocalAI Project Structure
|
||||
|
||||
**LocalAI is made of multiple components, developed in multiple repositories:**
|
||||
|
||||
The core repository, containing the primary `local-ai` server code, gRPC stubs, this documentation website, and docker container building resources are all located at [mudler/LocalAI](https://github.com/mudler/LocalAI).
|
||||
|
||||
As LocalAI is designed to make use of multiple, independent model galleries, those are maintained seperately. The following public model galleries are available for use:
|
||||
|
||||
* [go-skynet/model-gallery](https://github.com/go-skynet/model-gallery) - The original gallery, the `golang` huggingface scraper ran into limits and was largely retired, so this now holds handmade yaml configs
|
||||
* [dave-gray101/model-gallery](https://github.com/dave-gray101/model-gallery) - An automated gallery designed to track HuggingFace uploads and produce best-effort automatically generated configurations for LocalAI. It is designed to produce one LocalAI gallery per repository on HuggingFace.
|
||||
|
||||
### Directory Structure of this Repo
|
||||
|
||||
The core repository is broken up into the following primary chunks:
|
||||
|
||||
* `/backend`: gRPC protobuf specification and gRPC backends. Subfolders for each language.
|
||||
* **`/core`**: golang sourcecode for the core LocalAI application. Broken down below.
|
||||
* `/docs`: localai.io website that you are reading now
|
||||
* `/examples`: example code integrating LocalAI to other projects and/or developer samples and tools
|
||||
* `/internal`: **here be dragons**. Don't touch this, it's used for automatic versioning.
|
||||
* `/models`: _No code here!_ This is where models are installed!
|
||||
* **`/pkg`**: golang sourcecode that is intended to be reusable or at least widely imported across LocalAI. Broken down below
|
||||
* `/prompt-templates`: _No code here!_ This is where **example** prompt templates were historically stored. Somewhat obsolete these days, model-galleries tend to replace manually creating these?
|
||||
* `/tests`: Does what it says on the tin. Please write tests and put them here when you do.
|
||||
|
||||
The `core` folder is broken down further:
|
||||
|
||||
* **`/core/backend`**: code that interacts with a gRPC backend to perform AI tasks.
|
||||
* `/core/http`: code specifically related to the REST server
|
||||
* `/core/http/endpoints`: Has two subdirectories, `openai` and `localai` for binding the respective endpoints to the correct backend or service.
|
||||
* `/core/mqtt`: core specifically related to the MQTT server. Stub for now. Coming soon!
|
||||
* **`/core/services`**: code implementing functionality performed by `local-ai` itself, rather than delegated to a backend.
|
||||
* `/core/startup`: code related specifically to application startup of `local-ai`. Potentially to be refactored to become a part of `/core/services` at a later date, or not.
|
||||
|
||||
The `pkg` folder is broken down further:
|
||||
|
||||
* `/pkg/assets`: Currently contains a single function related to extracting files from archives. Potentially to be refactored to become a part of `/core/utils` at a later date?
|
||||
* **`/pkg/datamodel`**: Contains the data types and definitions used by the LocalAI project. Imported widely!
|
||||
* `/pkg/gallery`: Code related to interacting with a `model-gallery`
|
||||
* `/pkg/grammar`: Code related to BNF / functions for LLM
|
||||
* `/pkg/grpc`: base classes and interfaces for gRPC backends to implement
|
||||
* `/pkg/langchain`: langchain related code in golang
|
||||
* **`/pkg/model`**: Code related to loading and initializing a model and creating the appropriate gRPC backend.
|
||||
* `/pkg/stablediffusion`: Code related to stablediffusion in golang.
|
||||
* `/pkg/utils`: Every real programmer knows what they are going to find in here... it's our junk drawer of utility functions.
|
||||
|
||||
|
||||
## Creating a gRPC backend
|
||||
|
||||
LocalAI backends are `gRPC` servers.
|
||||
|
@ -20,7 +20,7 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
|
||||
Returns an `audio/wav` file.
|
||||
|
||||
#### Setup
|
||||
#### Text-To-Speech Setup
|
||||
|
||||
LocalAI supports [bark]({{%relref "model-compatibility/bark" %}}) , `piper` and `vall-e-x`:
|
||||
|
||||
@ -52,6 +52,8 @@ Note:
|
||||
- The model name is case sensitive.
|
||||
- LocalAI must be compiled with the `GO_TAGS=tts` flag.
|
||||
|
||||
#### Music
|
||||
|
||||
LocalAI also has experimental support for `transformers-musicgen` for the generation of short musical compositions. Currently, this is implemented via the same requests used for text to speech:
|
||||
|
||||
```
|
||||
@ -62,7 +64,8 @@ curl --request POST \
|
||||
"backend": "transformers-musicgen",
|
||||
"model": "facebook/musicgen-medium",
|
||||
"input": "Cello Rave"
|
||||
}' | aplay```
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
Future versions of LocalAI will expose additional control over audio generation beyond the text prompt.
|
||||
|
||||
|
4
go.mod
4
go.mod
@ -5,6 +5,7 @@ go 1.21
|
||||
require (
|
||||
github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf
|
||||
github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1
|
||||
@ -15,7 +16,6 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hpcloud/tail v1.0.0
|
||||
github.com/imdario/mergo v0.3.16
|
||||
github.com/json-iterator/go v1.1.12
|
||||
github.com/mholt/archiver/v3 v3.5.1
|
||||
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
|
||||
github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af
|
||||
@ -63,8 +63,6 @@ require (
|
||||
github.com/klauspost/pgzip v1.2.5 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/nwaples/rardecode v1.1.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.2 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
|
||||
|
11
go.sum
11
go.sum
@ -24,8 +24,9 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L
|
||||
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
|
||||
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ=
|
||||
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo=
|
||||
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||
@ -74,7 +75,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
@ -88,8 +88,6 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=
|
||||
github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw=
|
||||
github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
|
||||
github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
@ -119,11 +117,6 @@ github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Cl
|
||||
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU=
|
||||
github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig=
|
||||
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI=
|
||||
|
120
main.go
120
main.go
@ -12,14 +12,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
api "github.com/go-skynet/LocalAI/api"
|
||||
"github.com/go-skynet/LocalAI/api/backend"
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/core/backend"
|
||||
"github.com/go-skynet/LocalAI/core/http"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/core/startup"
|
||||
"github.com/go-skynet/LocalAI/internal"
|
||||
"github.com/go-skynet/LocalAI/metrics"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
progressbar "github.com/schollz/progressbar/v3"
|
||||
@ -190,6 +190,12 @@ func main() {
|
||||
EnvVars: []string{"PRELOAD_BACKEND_ONLY"},
|
||||
Value: false,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "localai-config-dir",
|
||||
Usage: "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.",
|
||||
EnvVars: []string{"LOCALAI_CONFIG_DIR"},
|
||||
Value: "./config",
|
||||
},
|
||||
},
|
||||
Description: `
|
||||
LocalAI is a drop-in replacement OpenAI API which runs inference locally.
|
||||
@ -208,54 +214,54 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
UsageText: `local-ai [options]`,
|
||||
Copyright: "Ettore Di Giacinto",
|
||||
Action: func(ctx *cli.Context) error {
|
||||
opts := []options.AppOption{
|
||||
options.WithConfigFile(ctx.String("config-file")),
|
||||
options.WithJSONStringPreload(ctx.String("preload-models")),
|
||||
options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||
options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
|
||||
options.WithContextSize(ctx.Int("context-size")),
|
||||
options.WithDebug(ctx.Bool("debug")),
|
||||
options.WithImageDir(ctx.String("image-path")),
|
||||
options.WithAudioDir(ctx.String("audio-path")),
|
||||
options.WithF16(ctx.Bool("f16")),
|
||||
options.WithStringGalleries(ctx.String("galleries")),
|
||||
options.WithDisableMessage(false),
|
||||
options.WithCors(ctx.Bool("cors")),
|
||||
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
||||
options.WithThreads(ctx.Int("threads")),
|
||||
options.WithBackendAssets(backendAssets),
|
||||
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
options.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||
options.WithApiKeys(ctx.StringSlice("api-keys")),
|
||||
options.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...),
|
||||
opts := []schema.AppOption{
|
||||
schema.WithConfigFile(ctx.String("config-file")),
|
||||
schema.WithJSONStringPreload(ctx.String("preload-models")),
|
||||
schema.WithYAMLConfigPreload(ctx.String("preload-models-config")),
|
||||
schema.WithModelPath(ctx.String("models-path")),
|
||||
schema.WithContextSize(ctx.Int("context-size")),
|
||||
schema.WithDebug(ctx.Bool("debug")),
|
||||
schema.WithImageDir(ctx.String("image-path")),
|
||||
schema.WithAudioDir(ctx.String("audio-path")),
|
||||
schema.WithF16(ctx.Bool("f16")),
|
||||
schema.WithStringGalleries(ctx.String("galleries")),
|
||||
schema.WithDisableMessage(false),
|
||||
schema.WithCors(ctx.Bool("cors")),
|
||||
schema.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
|
||||
schema.WithThreads(ctx.Int("threads")),
|
||||
schema.WithBackendAssets(backendAssets),
|
||||
schema.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
|
||||
schema.WithUploadLimitMB(ctx.Int("upload-limit")),
|
||||
schema.WithApiKeys(ctx.StringSlice("api-keys")),
|
||||
schema.WithModelsURL(append(ctx.StringSlice("models"), ctx.Args().Slice()...)...),
|
||||
}
|
||||
|
||||
idleWatchDog := ctx.Bool("enable-watchdog-idle")
|
||||
busyWatchDog := ctx.Bool("enable-watchdog-busy")
|
||||
if idleWatchDog || busyWatchDog {
|
||||
opts = append(opts, options.EnableWatchDog)
|
||||
opts = append(opts, schema.EnableWatchDog)
|
||||
if idleWatchDog {
|
||||
opts = append(opts, options.EnableWatchDogIdleCheck)
|
||||
opts = append(opts, schema.EnableWatchDogIdleCheck)
|
||||
dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts, options.SetWatchDogIdleTimeout(dur))
|
||||
opts = append(opts, schema.SetWatchDogIdleTimeout(dur))
|
||||
}
|
||||
if busyWatchDog {
|
||||
opts = append(opts, options.EnableWatchDogBusyCheck)
|
||||
opts = append(opts, schema.EnableWatchDogBusyCheck)
|
||||
dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts, options.SetWatchDogBusyTimeout(dur))
|
||||
opts = append(opts, schema.SetWatchDogBusyTimeout(dur))
|
||||
}
|
||||
}
|
||||
if ctx.Bool("parallel-requests") {
|
||||
opts = append(opts, options.EnableParallelBackendRequests)
|
||||
opts = append(opts, schema.EnableParallelBackendRequests)
|
||||
}
|
||||
if ctx.Bool("single-active-backend") {
|
||||
opts = append(opts, options.EnableSingleBackend)
|
||||
opts = append(opts, schema.EnableSingleBackend)
|
||||
}
|
||||
|
||||
externalgRPC := ctx.StringSlice("external-grpc-backends")
|
||||
@ -263,30 +269,42 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
for _, v := range externalgRPC {
|
||||
backend := v[:strings.IndexByte(v, ':')]
|
||||
uri := v[strings.IndexByte(v, ':')+1:]
|
||||
opts = append(opts, options.WithExternalBackend(backend, uri))
|
||||
opts = append(opts, schema.WithExternalBackend(backend, uri))
|
||||
}
|
||||
|
||||
if ctx.Bool("autoload-galleries") {
|
||||
opts = append(opts, options.EnableGalleriesAutoload)
|
||||
opts = append(opts, schema.EnableGalleriesAutoload)
|
||||
}
|
||||
|
||||
if ctx.Bool("preload-backend-only") {
|
||||
_, _, err := api.Startup(opts...)
|
||||
_, _, _, err := startup.Startup(opts...)
|
||||
return err
|
||||
}
|
||||
|
||||
metrics, err := metrics.SetupMetrics()
|
||||
metrics, err := services.SetupMetrics()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts, options.WithMetrics(metrics))
|
||||
opts = append(opts, schema.WithMetrics(metrics))
|
||||
|
||||
app, err := api.App(opts...)
|
||||
cl, ml, options, err := startup.Startup(opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed basic startup tasks with error %s", err.Error())
|
||||
}
|
||||
|
||||
closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options)
|
||||
|
||||
defer closeConfigWatcherFn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed while watching configuration directory %s", ctx.String("localai-config-dir"))
|
||||
}
|
||||
|
||||
appHTTP, err := http.App(cl, ml, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return app.Listen(ctx.String("address"))
|
||||
return appHTTP.Listen(ctx.String("address"))
|
||||
},
|
||||
Commands: []*cli.Command{
|
||||
{
|
||||
@ -384,16 +402,18 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
|
||||
text := strings.Join(ctx.Args().Slice(), " ")
|
||||
|
||||
opts := &options.Option{
|
||||
Loader: model.NewModelLoader(ctx.String("models-path")),
|
||||
opts := &schema.StartupOptions{
|
||||
ModelPath: ctx.String("models-path"),
|
||||
Context: context.Background(),
|
||||
AudioDir: outputDir,
|
||||
AssetsDestination: ctx.String("backend-assets-path"),
|
||||
}
|
||||
|
||||
defer opts.Loader.StopAllGRPC()
|
||||
loader := model.NewModelLoader(opts.ModelPath)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts)
|
||||
defer loader.StopAllGRPC()
|
||||
|
||||
filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, loader, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -446,13 +466,15 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
language := ctx.String("language")
|
||||
threads := ctx.Int("threads")
|
||||
|
||||
opts := &options.Option{
|
||||
Loader: model.NewModelLoader(ctx.String("models-path")),
|
||||
opts := &schema.StartupOptions{
|
||||
ModelPath: ctx.String("models-path"),
|
||||
Context: context.Background(),
|
||||
AssetsDestination: ctx.String("backend-assets-path"),
|
||||
}
|
||||
|
||||
cl := config.NewConfigLoader()
|
||||
ml := model.NewModelLoader(opts.ModelPath)
|
||||
|
||||
cl := services.NewConfigLoader()
|
||||
if err := cl.LoadConfigs(ctx.String("models-path")); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -464,9 +486,9 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||
|
||||
c.Threads = threads
|
||||
|
||||
defer opts.Loader.StopAllGRPC()
|
||||
defer ml.StopAllGRPC()
|
||||
|
||||
tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts)
|
||||
tr, err := backend.ModelTranscription(filename, language, ml, c, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,83 +0,0 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||
api "go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
meter api.Meter
|
||||
apiTimeMetric api.Float64Histogram
|
||||
}
|
||||
|
||||
// setupOTelSDK bootstraps the OpenTelemetry pipeline.
|
||||
// If it does not return an error, make sure to call shutdown for proper cleanup.
|
||||
func SetupMetrics() (*Metrics, error) {
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
meter := provider.Meter("github.com/go-skynet/LocalAI")
|
||||
|
||||
apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Metrics{
|
||||
meter: meter,
|
||||
apiTimeMetric: apiTimeMetric,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func MetricsHandler() fiber.Handler {
|
||||
return adaptor.HTTPHandler(promhttp.Handler())
|
||||
}
|
||||
|
||||
type apiMiddlewareConfig struct {
|
||||
Filter func(c *fiber.Ctx) bool
|
||||
metrics *Metrics
|
||||
}
|
||||
|
||||
func APIMiddleware(metrics *Metrics) fiber.Handler {
|
||||
cfg := apiMiddlewareConfig{
|
||||
metrics: metrics,
|
||||
Filter: func(c *fiber.Ctx) bool {
|
||||
if c.Path() == "/metrics" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if cfg.Filter != nil && cfg.Filter(c) {
|
||||
return c.Next()
|
||||
}
|
||||
path := c.Path()
|
||||
method := c.Method()
|
||||
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
cfg.metrics.ObserveAPICall(method, path, elapsed)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) ObserveAPICall(method string, path string, duration float64) {
|
||||
opts := api.WithAttributes(
|
||||
attribute.String("method", method),
|
||||
attribute.String("path", path),
|
||||
)
|
||||
m.apiTimeMetric.Record(context.Background(), duration, opts)
|
||||
}
|
@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
applyModel := func(model *GalleryModel) error {
|
||||
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
|
||||
|
||||
var config Config
|
||||
var config InstallableModel
|
||||
|
||||
if len(model.URL) > 0 {
|
||||
var err error
|
||||
config, err = GetGalleryConfigFromURL(model.URL)
|
||||
config, err = GetInstallableModelFromURL(model.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config = Config{
|
||||
config = InstallableModel{
|
||||
ConfigFile: string(reYamlConfig),
|
||||
Description: model.Description,
|
||||
License: model.License,
|
||||
|
@ -1,13 +1,9 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/imdario/mergo"
|
||||
@ -41,9 +37,9 @@ prompt_templates:
|
||||
content: ""
|
||||
|
||||
*/
|
||||
// Config is the model configuration which contains all the model details
|
||||
// InstallableModel is the model configuration which contains all the model details
|
||||
// This configuration is read from the gallery endpoint and is used to download and install the model
|
||||
type Config struct {
|
||||
type InstallableModel struct {
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license"`
|
||||
URLs []string `yaml:"urls"`
|
||||
@ -64,8 +60,8 @@ type PromptTemplate struct {
|
||||
Content string `yaml:"content"`
|
||||
}
|
||||
|
||||
func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||
var config Config
|
||||
func GetInstallableModelFromURL(url string) (InstallableModel, error) {
|
||||
var config InstallableModel
|
||||
err := utils.GetURI(url, func(url string, d []byte) error {
|
||||
return yaml.Unmarshal(d, &config)
|
||||
})
|
||||
@ -76,7 +72,7 @@ func GetGalleryConfigFromURL(url string) (Config, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func ReadConfigFile(filePath string) (*Config, error) {
|
||||
func ReadInstallableModelFile(filePath string) (*InstallableModel, error) {
|
||||
// Read the YAML file
|
||||
yamlFile, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
@ -84,7 +80,7 @@ func ReadConfigFile(filePath string) (*Config, error) {
|
||||
}
|
||||
|
||||
// Unmarshal YAML data into a Config struct
|
||||
var config Config
|
||||
var config InstallableModel
|
||||
err = yaml.Unmarshal(yamlFile, &config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
|
||||
@ -93,7 +89,7 @@ func ReadConfigFile(filePath string) (*Config, error) {
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
||||
func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
|
||||
// Create base path if it doesn't exist
|
||||
err := os.MkdirAll(basePath, 0755)
|
||||
if err != nil {
|
||||
@ -183,54 +179,3 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
written int64
|
||||
downloadStatus func(string, string, string, float64)
|
||||
hash hash.Hash
|
||||
}
|
||||
|
||||
func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = pw.hash.Write(p)
|
||||
pw.written += int64(n)
|
||||
|
||||
if pw.total > 0 {
|
||||
percentage := float64(pw.written) / float64(pw.total) * 100
|
||||
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
|
||||
} else {
|
||||
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return strconv.FormatInt(bytes, 10) + " B"
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func calculateSHA(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, file); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ var _ = Describe("Model test", func() {
|
||||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||
@ -87,7 +87,7 @@ var _ = Describe("Model test", func() {
|
||||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||
@ -103,7 +103,7 @@ var _ = Describe("Model test", func() {
|
||||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {})
|
||||
@ -129,7 +129,7 @@ var _ = Describe("Model test", func() {
|
||||
tempdir, err := os.MkdirTemp("", "test")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer os.RemoveAll(tempdir)
|
||||
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {})
|
||||
|
18
pkg/gallery/op.go
Normal file
18
pkg/gallery/op.go
Normal file
@ -0,0 +1,18 @@
|
||||
package gallery
|
||||
|
||||
type GalleryOp struct {
|
||||
Req GalleryModel
|
||||
Id string
|
||||
Galleries []Gallery
|
||||
GalleryName string
|
||||
}
|
||||
|
||||
type GalleryOpStatus struct {
|
||||
FileName string `json:"file_name"`
|
||||
Error error `json:"error"`
|
||||
Processed bool `json:"processed"`
|
||||
Message string `json:"message"`
|
||||
Progress float64 `json:"progress"`
|
||||
TotalFileSize string `json:"file_size"`
|
||||
DownloadedFileSize string `json:"downloaded_size"`
|
||||
}
|
@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
|
||||
Context("requests", func() {
|
||||
It("parses github with a branch", func() {
|
||||
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
|
||||
e, err := GetGalleryConfigFromURL(req.URL)
|
||||
e, err := GetInstallableModelFromURL(req.URL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(e.Name).To(Equal("gpt4all-j"))
|
||||
})
|
||||
|
@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
@ -53,8 +53,9 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) {
|
||||
return schema.Result{}, fmt.Errorf("unimplemented")
|
||||
// TODO CHECK THIS
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error) {
|
||||
return schema.WhisperResult{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) TTS(*pb.TTSRequest) error {
|
||||
|
@ -7,8 +7,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
|
||||
return client.TTS(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) {
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.WhisperResult, error) {
|
||||
if !c.parallel {
|
||||
c.opMutex.Lock()
|
||||
defer c.opMutex.Unlock()
|
||||
@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tresult := &schema.Result{}
|
||||
tresult := &schema.WhisperResult{}
|
||||
for _, s := range res.Segments {
|
||||
tks := []int{}
|
||||
for _, t := range s.Tokens {
|
||||
tks = append(tks, int(t))
|
||||
}
|
||||
tresult.Segments = append(tresult.Segments,
|
||||
schema.Segment{
|
||||
schema.WhisperSegment{
|
||||
Text: s.Text,
|
||||
Id: int(s.Id),
|
||||
Start: time.Duration(s.Start),
|
||||
|
@ -1,8 +1,8 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
)
|
||||
|
||||
type LLM interface {
|
||||
@ -15,7 +15,7 @@ type LLM interface {
|
||||
Load(*pb.ModelOptions) error
|
||||
Embeddings(*pb.PredictOptions) ([]float32, error)
|
||||
GenerateImage(*pb.GenerateImageRequest) error
|
||||
AudioTranscription(*pb.TranscriptRequest) (schema.Result, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (schema.WhisperResult, error)
|
||||
TTS(*pb.TTSRequest) error
|
||||
TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
|
||||
Status() (pb.StatusResponse, error)
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.6.1
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v4.26.0
|
||||
// source: backend.proto
|
||||
|
||||
package proto
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v3.6.1
|
||||
// - protoc-gen-go-grpc v1.3.0
|
||||
// - protoc v4.26.0
|
||||
// source: backend.proto
|
||||
|
||||
package proto
|
||||
@ -18,6 +18,19 @@ import (
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
const (
|
||||
Backend_Health_FullMethodName = "/backend.Backend/Health"
|
||||
Backend_Predict_FullMethodName = "/backend.Backend/Predict"
|
||||
Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel"
|
||||
Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream"
|
||||
Backend_Embedding_FullMethodName = "/backend.Backend/Embedding"
|
||||
Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage"
|
||||
Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription"
|
||||
Backend_TTS_FullMethodName = "/backend.Backend/TTS"
|
||||
Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString"
|
||||
Backend_Status_FullMethodName = "/backend.Backend/Status"
|
||||
)
|
||||
|
||||
// BackendClient is the client API for Backend service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
@ -44,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
|
||||
|
||||
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
|
||||
out := new(Reply)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -53,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g
|
||||
|
||||
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
|
||||
out := new(Reply)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -62,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ..
|
||||
|
||||
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -70,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ..
|
||||
}
|
||||
|
||||
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -103,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) {
|
||||
|
||||
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
|
||||
out := new(EmbeddingResult)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -112,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts
|
||||
|
||||
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -121,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ
|
||||
|
||||
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
|
||||
out := new(TranscriptResult)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -130,7 +143,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe
|
||||
|
||||
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
|
||||
out := new(Result)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -139,7 +152,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca
|
||||
|
||||
func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) {
|
||||
out := new(TokenizationResponse)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -148,7 +161,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions,
|
||||
|
||||
func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) {
|
||||
out := new(StatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -229,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Health",
|
||||
FullMethod: Backend_Health_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
|
||||
@ -247,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Predict",
|
||||
FullMethod: Backend_Predict_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
|
||||
@ -265,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/LoadModel",
|
||||
FullMethod: Backend_LoadModel_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
|
||||
@ -304,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Embedding",
|
||||
FullMethod: Backend_Embedding_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
|
||||
@ -322,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/GenerateImage",
|
||||
FullMethod: Backend_GenerateImage_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
|
||||
@ -340,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/AudioTranscription",
|
||||
FullMethod: Backend_AudioTranscription_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
|
||||
@ -358,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/TTS",
|
||||
FullMethod: Backend_TTS_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
|
||||
@ -376,7 +389,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/TokenizeString",
|
||||
FullMethod: Backend_TokenizeString_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions))
|
||||
@ -394,7 +407,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/backend.Backend/Status",
|
||||
FullMethod: Backend_Status_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BackendServer).Status(ctx, req.(*HealthMessage))
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{
|
||||
|
||||
// starts the grpcModelProcess for the backend, and returns a grpc client
|
||||
// It also loads the model
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
|
||||
func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) {
|
||||
return func(modelName, modelFile string) (ModelAddress, error) {
|
||||
log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
|
||||
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
grammar "github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc"
|
||||
process "github.com/mudler/go-processmanager"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
type ModelOptions struct {
|
||||
backendString string
|
||||
model string
|
||||
threads uint32
|
||||
@ -23,14 +23,14 @@ type Options struct {
|
||||
parallelRequests bool
|
||||
}
|
||||
|
||||
type Option func(*Options)
|
||||
type Option func(*ModelOptions)
|
||||
|
||||
var EnableParallelRequests = func(o *Options) {
|
||||
var EnableParallelRequests = func(o *ModelOptions) {
|
||||
o.parallelRequests = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
if o.externalBackends == nil {
|
||||
o.externalBackends = make(map[string]string)
|
||||
}
|
||||
@ -38,62 +38,81 @@ func WithExternalBackend(name string, uri string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates
|
||||
func WithExternalBackends(backends map[string]string, overwrite bool) Option {
|
||||
return func(o *ModelOptions) {
|
||||
if backends == nil {
|
||||
return
|
||||
}
|
||||
if o.externalBackends == nil {
|
||||
o.externalBackends = backends
|
||||
return
|
||||
}
|
||||
for name, url := range backends {
|
||||
_, exists := o.externalBackends[name]
|
||||
if !exists || overwrite {
|
||||
o.externalBackends[name] = url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithGRPCAttempts(attempts int) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.grpcAttempts = attempts
|
||||
}
|
||||
}
|
||||
|
||||
func WithGRPCAttemptsDelay(delay int) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.grpcAttemptsDelay = delay
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendString(backend string) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.backendString = backend
|
||||
}
|
||||
}
|
||||
|
||||
func WithModel(modelFile string) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.model = modelFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.gRPCOptions = opts
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads uint32) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.threads = threads
|
||||
}
|
||||
}
|
||||
|
||||
func WithAssetDir(assetDir string) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.assetDir = assetDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithSingleActiveBackend() Option {
|
||||
return func(o *Options) {
|
||||
return func(o *ModelOptions) {
|
||||
o.singleActiveBackend = true
|
||||
}
|
||||
}
|
||||
|
||||
func NewOptions(opts ...Option) *Options {
|
||||
o := &Options{
|
||||
func NewOptions(opts ...Option) *ModelOptions {
|
||||
o := &ModelOptions{
|
||||
gRPCOptions: &pb.ModelOptions{},
|
||||
context: context.Background(),
|
||||
grpcAttempts: 20,
|
||||
|
@ -1,16 +1,11 @@
|
||||
package api_config
|
||||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@ -152,11 +147,6 @@ type TemplateConfig struct {
|
||||
Functions string `yaml:"function"`
|
||||
}
|
||||
|
||||
type ConfigLoader struct {
|
||||
configs map[string]Config
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *Config) SetFunctionCallString(s string) {
|
||||
c.functionCallString = s
|
||||
}
|
||||
@ -193,11 +183,6 @@ func DefaultConfig(modelFile string) *Config {
|
||||
}
|
||||
}
|
||||
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
configs: make(map[string]Config),
|
||||
}
|
||||
}
|
||||
func ReadConfigFile(file string) ([]*Config, error) {
|
||||
c := &[]*Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
@ -211,7 +196,7 @@ func ReadConfigFile(file string) ([]*Config, error) {
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
func ReadConfig(file string) (*Config, error) {
|
||||
func ReadSingleConfigFile(file string) (*Config, error) {
|
||||
c := &Config{}
|
||||
f, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@ -224,136 +209,192 @@ func ReadConfig(file string) (*Config, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfigFile(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfigFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot load config file: %w", err)
|
||||
func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
config.TopK = input.TopK
|
||||
}
|
||||
if input.TopP != 0 {
|
||||
config.TopP = input.TopP
|
||||
}
|
||||
|
||||
for _, cc := range c {
|
||||
cm.configs[cc.Name] = *cc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfig(file string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
c, err := ReadConfig(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read config file: %w", err)
|
||||
if input.Backend != "" {
|
||||
config.Backend = input.Backend
|
||||
}
|
||||
|
||||
cm.configs[c.Name] = *c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
v, exists := cm.configs[m]
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) GetAllConfigs() []Config {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []Config
|
||||
for _, v := range cm.configs {
|
||||
res = append(res, v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) ListConfigs() []string {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
var res []string
|
||||
for k := range cm.configs {
|
||||
res = append(res, k)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Preload prepare models if they are not local but url or huggingface repositories
|
||||
func (cm *ConfigLoader) Preload(modelPath string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
|
||||
status := func(fileName, current, total string, percent float64) {
|
||||
utils.DisplayDownloadFunction(fileName, current, total, percent)
|
||||
if input.ClipSkip != 0 {
|
||||
config.Diffusers.ClipSkip = input.ClipSkip
|
||||
}
|
||||
|
||||
log.Info().Msgf("Preloading models from %s", modelPath)
|
||||
if input.ModelBaseName != "" {
|
||||
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
||||
}
|
||||
|
||||
for i, config := range cm.configs {
|
||||
if input.NegativePromptScale != 0 {
|
||||
config.NegativePromptScale = input.NegativePromptScale
|
||||
}
|
||||
|
||||
// Download files and verify their SHA
|
||||
for _, file := range config.DownloadFiles {
|
||||
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
|
||||
if input.UseFastTokenizer {
|
||||
config.UseFastTokenizer = input.UseFastTokenizer
|
||||
}
|
||||
|
||||
if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Create file path
|
||||
filePath := filepath.Join(modelPath, file.Filename)
|
||||
if input.NegativePrompt != "" {
|
||||
config.NegativePrompt = input.NegativePrompt
|
||||
}
|
||||
|
||||
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
|
||||
return err
|
||||
if input.RopeFreqBase != 0 {
|
||||
config.RopeFreqBase = input.RopeFreqBase
|
||||
}
|
||||
|
||||
if input.RopeFreqScale != 0 {
|
||||
config.RopeFreqScale = input.RopeFreqScale
|
||||
}
|
||||
|
||||
if input.Grammar != "" {
|
||||
config.Grammar = input.Grammar
|
||||
}
|
||||
|
||||
if input.Temperature != 0 {
|
||||
config.Temperature = input.Temperature
|
||||
}
|
||||
|
||||
if input.Maxtokens != 0 {
|
||||
config.Maxtokens = input.Maxtokens
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
||||
if input.Keep != 0 {
|
||||
config.Keep = input.Keep
|
||||
}
|
||||
|
||||
if input.Batch != 0 {
|
||||
config.Batch = input.Batch
|
||||
}
|
||||
|
||||
if input.F16 {
|
||||
config.F16 = input.F16
|
||||
}
|
||||
|
||||
if input.IgnoreEOS {
|
||||
config.IgnoreEOS = input.IgnoreEOS
|
||||
}
|
||||
|
||||
if input.Seed != 0 {
|
||||
config.Seed = input.Seed
|
||||
}
|
||||
|
||||
if input.Mirostat != 0 {
|
||||
config.LLMConfig.Mirostat = input.Mirostat
|
||||
}
|
||||
|
||||
if input.MirostatETA != 0 {
|
||||
config.LLMConfig.MirostatETA = input.MirostatETA
|
||||
}
|
||||
|
||||
if input.MirostatTAU != 0 {
|
||||
config.LLMConfig.MirostatTAU = input.MirostatTAU
|
||||
}
|
||||
|
||||
if input.TypicalP != 0 {
|
||||
config.TypicalP = input.TypicalP
|
||||
}
|
||||
|
||||
switch stop := input.Stop.(type) {
|
||||
case string:
|
||||
if stop != "" {
|
||||
config.StopWords = append(config.StopWords, stop)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range stop {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.StopWords = append(config.StopWords, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modelURL := config.PredictionOptions.Model
|
||||
modelURL = utils.ConvertURL(modelURL)
|
||||
// Decode each request's message content
|
||||
index := 0
|
||||
for i, m := range input.Messages {
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
input.Messages[i].StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
for _, pp := range c {
|
||||
if pp.Type == "text" {
|
||||
input.Messages[i].StringContent = pp.Text
|
||||
} else if pp.Type == "image_url" {
|
||||
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
||||
base64, err := utils.GetBase64Image(pp.ImageURL.URL)
|
||||
if err == nil {
|
||||
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
// set a placeholder for each image
|
||||
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
|
||||
index++
|
||||
} else {
|
||||
fmt.Print("Failed encoding image", err)
|
||||
}
|
||||
|
||||
if utils.LooksLikeURL(modelURL) {
|
||||
// md5 of model name
|
||||
md5Name := utils.MD5(modelURL)
|
||||
|
||||
// check if file exists
|
||||
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
|
||||
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cc := cm.configs[i]
|
||||
c := &cc
|
||||
c.PredictionOptions.Model = md5Name
|
||||
cm.configs[i] = *c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigLoader) LoadConfigs(path string) error {
|
||||
cm.Lock()
|
||||
defer cm.Unlock()
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files := make([]fs.FileInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files = append(files, info)
|
||||
}
|
||||
for _, file := range files {
|
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
|
||||
continue
|
||||
}
|
||||
c, err := ReadConfig(filepath.Join(path, file.Name()))
|
||||
if err == nil {
|
||||
cm.configs[c.Name] = *c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
// TODO: check that this was merged correctly? I _think_ it is?
|
||||
switch inputs := input.Input.(type) {
|
||||
case string:
|
||||
if inputs != "" {
|
||||
config.InputStrings = append(config.InputStrings, inputs)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, pp := range inputs {
|
||||
switch i := pp.(type) {
|
||||
case string:
|
||||
config.InputStrings = append(config.InputStrings, i)
|
||||
case []interface{}:
|
||||
tokens := []int{}
|
||||
for _, ii := range i {
|
||||
tokens = append(tokens, int(ii.(float64)))
|
||||
}
|
||||
config.InputToken = append(config.InputToken, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) {
|
||||
case string:
|
||||
if fnc != "" {
|
||||
config.SetFunctionCallString(fnc)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
var name string
|
||||
n, exists := fnc["name"]
|
||||
if exists {
|
||||
nn, e := n.(string)
|
||||
if e {
|
||||
name = nn
|
||||
}
|
||||
}
|
||||
config.SetFunctionCallNameString(name)
|
||||
}
|
||||
|
||||
switch p := input.Prompt.(type) {
|
||||
case string:
|
||||
config.PromptStrings = append(config.PromptStrings, p)
|
||||
case []interface{}:
|
||||
for _, pp := range p {
|
||||
if s, ok := pp.(string); ok {
|
||||
config.PromptStrings = append(config.PromptStrings, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -1,11 +1,10 @@
|
||||
package api_config_test
|
||||
package schema_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
. "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/core/services"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
Context("Test Read configuration functions", func() {
|
||||
configFile = os.Getenv("CONFIG_FILE")
|
||||
It("Test ReadConfigFile", func() {
|
||||
config, err := ReadConfigFile(configFile)
|
||||
config, err := schema.ReadConfigFile(configFile)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(config).ToNot(BeNil())
|
||||
// two configs in config.yaml
|
||||
@ -28,12 +27,8 @@ var _ = Describe("Test cases for config related functions", func() {
|
||||
})
|
||||
|
||||
It("Test LoadConfigs", func() {
|
||||
cm := NewConfigLoader()
|
||||
opts := options.NewOptions()
|
||||
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH"))
|
||||
options.WithModelLoader(modelLoader)(opts)
|
||||
|
||||
err := cm.LoadConfigs(opts.Loader.ModelPath)
|
||||
cm := services.NewConfigLoader()
|
||||
err := cm.LoadConfigs(os.Getenv("MODELS_PATH"))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(cm.ListConfigs()).ToNot(BeNil())
|
||||
|
39
pkg/schema/localai.go
Normal file
39
pkg/schema/localai.go
Normal file
@ -0,0 +1,39 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
gopsutil "github.com/shirou/gopsutil/v3/process"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
type BackendMonitorRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
}
|
||||
|
||||
type BackendMonitorResponse struct {
|
||||
MemoryInfo *gopsutil.MemoryInfoStat
|
||||
MemoryPercent float32
|
||||
CPUPercent float64
|
||||
}
|
||||
|
||||
type TTSRequest struct {
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Input string `json:"input" yaml:"input"`
|
||||
Backend string `json:"backend" yaml:"backend"`
|
||||
}
|
||||
|
||||
type LocalAIMetrics struct {
|
||||
Meter metric.Meter
|
||||
ApiTimeMetric metric.Float64Histogram
|
||||
}
|
||||
|
||||
func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) {
|
||||
opts := metric.WithAttributes(
|
||||
attribute.String("method", method),
|
||||
attribute.String("path", path),
|
||||
)
|
||||
m.ApiTimeMetric.Record(context.Background(), duration, opts)
|
||||
}
|
@ -3,8 +3,6 @@ package schema
|
||||
import (
|
||||
"context"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
)
|
||||
|
||||
@ -90,7 +88,7 @@ type ChatCompletionResponseFormat struct {
|
||||
}
|
||||
|
||||
type OpenAIRequest struct {
|
||||
config.PredictionOptions
|
||||
PredictionOptions
|
||||
|
||||
Context context.Context
|
||||
Cancel context.CancelFunc
|
@ -1,4 +1,4 @@
|
||||
package api_config
|
||||
package schema
|
||||
|
||||
type PredictionOptions struct {
|
||||
|
@ -1,4 +1,4 @@
|
||||
package options
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -6,16 +6,14 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/go-skynet/LocalAI/metrics"
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
type StartupOptions struct {
|
||||
Context context.Context
|
||||
ConfigFile string
|
||||
Loader *model.ModelLoader
|
||||
ModelPath string
|
||||
UploadLimitMB, Threads, ContextSize int
|
||||
F16 bool
|
||||
Debug, DisableMessage bool
|
||||
@ -26,7 +24,7 @@ type Option struct {
|
||||
PreloadModelsFromPath string
|
||||
CORSAllowOrigins string
|
||||
ApiKeys []string
|
||||
Metrics *metrics.Metrics
|
||||
Metrics *LocalAIMetrics
|
||||
|
||||
Galleries []gallery.Gallery
|
||||
|
||||
@ -47,12 +45,14 @@ type Option struct {
|
||||
ModelsURL []string
|
||||
|
||||
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
|
||||
|
||||
LocalAIConfigDir string
|
||||
}
|
||||
|
||||
type AppOption func(*Option)
|
||||
type AppOption func(*StartupOptions)
|
||||
|
||||
func NewOptions(o ...AppOption) *Option {
|
||||
opt := &Option{
|
||||
func NewStartupOptions(o ...AppOption) *StartupOptions {
|
||||
opt := &StartupOptions{
|
||||
Context: context.Background(),
|
||||
UploadLimitMB: 15,
|
||||
Threads: 1,
|
||||
@ -67,57 +67,57 @@ func NewOptions(o ...AppOption) *Option {
|
||||
}
|
||||
|
||||
func WithModelsURL(urls ...string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.ModelsURL = urls
|
||||
}
|
||||
}
|
||||
|
||||
func WithCors(b bool) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.CORS = b
|
||||
}
|
||||
}
|
||||
|
||||
var EnableWatchDog = func(o *Option) {
|
||||
var EnableWatchDog = func(o *StartupOptions) {
|
||||
o.WatchDog = true
|
||||
}
|
||||
|
||||
var EnableWatchDogIdleCheck = func(o *Option) {
|
||||
var EnableWatchDogIdleCheck = func(o *StartupOptions) {
|
||||
o.WatchDog = true
|
||||
o.WatchDogIdle = true
|
||||
}
|
||||
|
||||
var EnableWatchDogBusyCheck = func(o *Option) {
|
||||
var EnableWatchDogBusyCheck = func(o *StartupOptions) {
|
||||
o.WatchDog = true
|
||||
o.WatchDogBusy = true
|
||||
}
|
||||
|
||||
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.WatchDogBusyTimeout = t
|
||||
}
|
||||
}
|
||||
|
||||
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.WatchDogIdleTimeout = t
|
||||
}
|
||||
}
|
||||
|
||||
var EnableSingleBackend = func(o *Option) {
|
||||
var EnableSingleBackend = func(o *StartupOptions) {
|
||||
o.SingleBackend = true
|
||||
}
|
||||
|
||||
var EnableParallelBackendRequests = func(o *Option) {
|
||||
var EnableParallelBackendRequests = func(o *StartupOptions) {
|
||||
o.ParallelBackendRequests = true
|
||||
}
|
||||
|
||||
var EnableGalleriesAutoload = func(o *Option) {
|
||||
var EnableGalleriesAutoload = func(o *StartupOptions) {
|
||||
o.AutoloadGalleries = true
|
||||
}
|
||||
|
||||
func WithExternalBackend(name string, uri string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
if o.ExternalGRPCBackends == nil {
|
||||
o.ExternalGRPCBackends = make(map[string]string)
|
||||
}
|
||||
@ -126,25 +126,25 @@ func WithExternalBackend(name string, uri string) AppOption {
|
||||
}
|
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.CORSAllowOrigins = b
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssetsOutput(out string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.AssetsDestination = out
|
||||
}
|
||||
}
|
||||
|
||||
func WithBackendAssets(f embed.FS) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.BackendAssets = f
|
||||
}
|
||||
}
|
||||
|
||||
func WithStringGalleries(galls string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
if galls == "" {
|
||||
log.Debug().Msgf("no galleries to load")
|
||||
o.Galleries = []gallery.Gallery{}
|
||||
@ -159,96 +159,102 @@ func WithStringGalleries(galls string) AppOption {
|
||||
}
|
||||
|
||||
func WithGalleries(galleries []gallery.Gallery) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.Galleries = append(o.Galleries, galleries...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.Context = ctx
|
||||
}
|
||||
}
|
||||
|
||||
func WithYAMLConfigPreload(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.PreloadModelsFromPath = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithJSONStringPreload(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.PreloadJSONModels = configFile
|
||||
}
|
||||
}
|
||||
func WithConfigFile(configFile string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.ConfigFile = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithModelLoader(loader *model.ModelLoader) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Loader = loader
|
||||
func WithModelPath(path string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.ModelPath = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithUploadLimitMB(limit int) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.UploadLimitMB = limit
|
||||
}
|
||||
}
|
||||
|
||||
func WithThreads(threads int) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.Threads = threads
|
||||
}
|
||||
}
|
||||
|
||||
func WithContextSize(ctxSize int) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.ContextSize = ctxSize
|
||||
}
|
||||
}
|
||||
|
||||
func WithF16(f16 bool) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.F16 = f16
|
||||
}
|
||||
}
|
||||
|
||||
func WithDebug(debug bool) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.Debug = debug
|
||||
}
|
||||
}
|
||||
|
||||
func WithDisableMessage(disableMessage bool) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.DisableMessage = disableMessage
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudioDir(audioDir string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.AudioDir = audioDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithImageDir(imageDir string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.ImageDir = imageDir
|
||||
}
|
||||
}
|
||||
|
||||
func WithApiKeys(apiKeys []string) AppOption {
|
||||
return func(o *Option) {
|
||||
return func(o *StartupOptions) {
|
||||
o.ApiKeys = apiKeys
|
||||
}
|
||||
}
|
||||
|
||||
func WithMetrics(meter *metrics.Metrics) AppOption {
|
||||
return func(o *Option) {
|
||||
o.Metrics = meter
|
||||
func WithMetrics(metrics *LocalAIMetrics) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.Metrics = metrics
|
||||
}
|
||||
}
|
||||
|
||||
func WithLocalAIConfigDir(configDir string) AppOption {
|
||||
return func(o *StartupOptions) {
|
||||
o.LocalAIConfigDir = configDir
|
||||
}
|
||||
}
|
@ -2,7 +2,7 @@ package schema
|
||||
|
||||
import "time"
|
||||
|
||||
type Segment struct {
|
||||
type WhisperSegment struct {
|
||||
Id int `json:"id"`
|
||||
Start time.Duration `json:"start"`
|
||||
End time.Duration `json:"end"`
|
||||
@ -10,7 +10,7 @@ type Segment struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Segments []Segment `json:"segments"`
|
||||
Text string `json:"text"`
|
||||
type WhisperResult struct {
|
||||
Segments []WhisperSegment `json:"segments"`
|
||||
Text string `json:"text"`
|
||||
}
|
81
pkg/utils/file.go
Normal file
81
pkg/utils/file.go
Normal file
@ -0,0 +1,81 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) {
|
||||
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Create a temporary file in the requested directory:
|
||||
outputFile, err := os.CreateTemp(tempDir, tempPattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer outputFile.Close()
|
||||
|
||||
if _, err := io.Copy(outputFile, f); err != nil {
|
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return outputFile.Name(), nil
|
||||
}
|
||||
|
||||
func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) {
|
||||
if len(base64data) == 0 {
|
||||
return "", fmt.Errorf("base64data empty?")
|
||||
}
|
||||
//base 64 decode the file and write it somewhere
|
||||
// that we will cleanup
|
||||
decoded, err := base64.StdEncoding.DecodeString(base64data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Create a temporary file in the requested directory:
|
||||
outputFile, err := os.CreateTemp(tempDir, tempPattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer outputFile.Close()
|
||||
// write the base64 result
|
||||
writer := bufio.NewWriter(outputFile)
|
||||
_, err = writer.Write(decoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return outputFile.Name(), nil
|
||||
}
|
||||
|
||||
func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) {
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create the file
|
||||
out, err := os.CreateTemp(tempDir, tempPattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Write the body to file
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
return out.Name(), err
|
||||
}
|
@ -3,18 +3,38 @@ package utils
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
HuggingFacePrefix = "huggingface://"
|
||||
HTTPPrefix = "http://"
|
||||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
)
|
||||
|
||||
func getRecognizedURIPrefixes() []string {
|
||||
return []string{
|
||||
HuggingFacePrefix,
|
||||
HTTPPrefix,
|
||||
HTTPSPrefix,
|
||||
GithubURI,
|
||||
GithubURI2,
|
||||
}
|
||||
}
|
||||
|
||||
func GetURI(url string, f func(url string, i []byte) error) error {
|
||||
url = ConvertURL(url)
|
||||
|
||||
@ -52,20 +72,8 @@ func GetURI(url string, f func(url string, i []byte) error) error {
|
||||
return f(url, body)
|
||||
}
|
||||
|
||||
const (
|
||||
HuggingFacePrefix = "huggingface://"
|
||||
HTTPPrefix = "http://"
|
||||
HTTPSPrefix = "https://"
|
||||
GithubURI = "github:"
|
||||
GithubURI2 = "github://"
|
||||
)
|
||||
|
||||
func LooksLikeURL(s string) bool {
|
||||
return strings.HasPrefix(s, HTTPPrefix) ||
|
||||
strings.HasPrefix(s, HTTPSPrefix) ||
|
||||
strings.HasPrefix(s, HuggingFacePrefix) ||
|
||||
strings.HasPrefix(s, GithubURI) ||
|
||||
strings.HasPrefix(s, GithubURI2)
|
||||
return slices.Contains(getRecognizedURIPrefixes(), s)
|
||||
}
|
||||
|
||||
func ConvertURL(s string) string {
|
||||
@ -241,6 +249,37 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// this function check if the string is an URL, if it's an URL downloads the image in memory
|
||||
// encodes it in base64 and returns the base64 string
|
||||
func GetBase64Image(s string) (string, error) {
|
||||
if strings.HasPrefix(s, "http") {
|
||||
// download the image
|
||||
resp, err := http.Get(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// read the image data into memory
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// encode the image data in base64
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
// return the base64 string
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
|
||||
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
|
||||
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
|
||||
}
|
||||
return "", fmt.Errorf("not valid string")
|
||||
}
|
||||
|
||||
type progressWriter struct {
|
||||
fileName string
|
||||
total int64
|
||||
|
@ -3,16 +3,16 @@ package integration_test
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/go-skynet/LocalAI/pkg/schema"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Integration Tests involving reflection in liue of code generation", func() {
|
||||
Context("config.TemplateConfig and model.TemplateType must stay in sync", func() {
|
||||
Context("schema.TemplateConfig and model.TemplateType must stay in sync", func() {
|
||||
|
||||
ttc := reflect.TypeOf(config.TemplateConfig{})
|
||||
ttc := reflect.TypeOf(schema.TemplateConfig{})
|
||||
|
||||
It("TemplateConfig and TemplateType should have the same number of valid values", func() {
|
||||
const lastValidTemplateType = model.IntegrationTestTemplate - 1
|
||||
|
Loading…
Reference in New Issue
Block a user