LocalAI/api/api.go

257 lines
8.0 KiB
Go
Raw Normal View History

2023-04-11 22:02:39 +00:00
package api
2023-04-11 21:43:43 +00:00
import (
2023-04-20 16:33:02 +00:00
"errors"
2023-08-18 19:23:14 +00:00
"fmt"
"strings"
2023-04-11 21:43:43 +00:00
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/localai"
"github.com/go-skynet/LocalAI/api/openai"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/api/schema"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/metrics"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/go-skynet/LocalAI/pkg/model"
2023-04-11 21:43:43 +00:00
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
2023-05-05 13:53:57 +00:00
"github.com/gofiber/fiber/v2/middleware/logger"
2023-04-11 21:43:43 +00:00
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/rs/zerolog"
2023-04-20 16:33:02 +00:00
"github.com/rs/zerolog/log"
2023-04-11 21:43:43 +00:00
)
2023-08-18 19:23:14 +00:00
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
options := options.NewOptions(opts...)
2023-05-21 12:38:25 +00:00
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
2023-08-18 19:23:14 +00:00
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()
if options.WatchDog {
wd := model.NewWatchDog(
options.Loader,
options.WatchDogBusyTimeout,
options.WatchDogIdleTimeout,
options.WatchDogBusy,
options.WatchDogIdle)
options.Loader.SetWatchDog(wd)
go wd.Run()
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
wd.Shutdown()
}()
}
2023-08-18 19:23:14 +00:00
return options, cl, nil
}
func App(opts ...options.AppOption) (*fiber.App, error) {
options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}
2023-04-20 16:33:02 +00:00
// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.DisableMessage,
2023-04-20 16:33:02 +00:00
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
2023-04-20 16:33:02 +00:00
},
2023-04-11 21:43:43 +00:00
})
if options.Debug {
2023-05-05 13:53:57 +00:00
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
2023-04-20 16:33:02 +00:00
// Default middleware config
app.Use(recover.New())
if options.Metrics != nil {
app.Use(metrics.APIMiddleware(options.Metrics))
}
2023-05-21 12:38:25 +00:00
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(options.ApiKeys) > 0 {
authHeader := c.Get("Authorization")
if authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
}
apiKey := authHeaderParts[1]
validApiKey := false
for _, key := range options.ApiKeys {
if apiKey == key {
validApiKey = true
}
}
if !validApiKey {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
}
}
return c.Next()
}
if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
c = cors.New()
2023-05-21 12:38:25 +00:00
} else {
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
2023-05-21 12:38:25 +00:00
}
app.Use(c)
2023-05-21 12:38:25 +00:00
}
2023-04-20 16:33:02 +00:00
// LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
2023-08-18 19:23:14 +00:00
galleryService.Start(options.Context, cl)
app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
2023-04-20 16:33:02 +00:00
// openAI compatible API endpoint
// chat
2023-08-18 19:23:14 +00:00
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
2023-04-20 16:33:02 +00:00
// edit
2023-08-18 19:23:14 +00:00
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
2023-04-29 07:22:09 +00:00
// completion
2023-08-18 19:23:14 +00:00
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
2023-04-20 16:33:02 +00:00
// embeddings
2023-08-18 19:23:14 +00:00
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
// audio
2023-08-18 19:23:14 +00:00
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
2023-05-09 09:43:50 +00:00
// images
2023-08-18 19:23:14 +00:00
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
}
if options.AudioDir != "" {
app.Static("/generated-audio", options.AudioDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
2023-08-18 19:23:14 +00:00
// Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
2023-04-20 16:33:02 +00:00
2023-08-18 19:23:14 +00:00
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/metrics", metrics.MetricsHandler())
return app, nil
2023-04-11 21:43:43 +00:00
}