LocalAI/api/api.go

143 lines
3.7 KiB
Go
Raw Normal View History

2023-04-11 22:02:39 +00:00
package api
2023-04-11 21:43:43 +00:00
import (
	"errors"
	"fmt"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/cors"
	"github.com/gofiber/fiber/v2/middleware/logger"
	"github.com/gofiber/fiber/v2/middleware/recover"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)
func App(opts ...AppOption) (*fiber.App, error) {
2023-05-21 12:38:25 +00:00
options := newOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
2023-05-21 12:38:25 +00:00
if options.debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
2023-04-20 16:33:02 +00:00
// Return errors as JSON responses
app := fiber.New(fiber.Config{
2023-05-21 12:38:25 +00:00
BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: options.disableMessage,
2023-04-20 16:33:02 +00:00
// Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
return ctx.Status(code).JSON(
ErrorResponse{
Error: &APIError{Message: err.Error(), Code: code},
},
)
2023-04-20 16:33:02 +00:00
},
2023-04-11 21:43:43 +00:00
})
2023-05-21 12:38:25 +00:00
if options.debug {
2023-05-05 13:53:57 +00:00
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
}
cm := NewConfigMerger()
2023-05-21 12:38:25 +00:00
if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
2023-05-21 12:38:25 +00:00
if options.configFile != "" {
if err := cm.LoadConfigFile(options.configFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
2023-05-21 12:38:25 +00:00
if options.debug {
for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
2023-04-20 16:33:02 +00:00
// Default middleware config
app.Use(recover.New())
2023-05-21 12:38:25 +00:00
if options.preloadJSONModels != "" {
if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm); err != nil {
return nil, err
}
}
if options.preloadModelsFromPath != "" {
if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm); err != nil {
return nil, err
}
}
2023-05-21 12:38:25 +00:00
if options.cors {
if options.corsAllowOrigins == "" {
app.Use(cors.New())
} else {
app.Use(cors.New(cors.Config{
AllowOrigins: options.corsAllowOrigins,
}))
}
}
2023-04-20 16:33:02 +00:00
// LocalAI API endpoints
2023-05-21 12:38:25 +00:00
applier := newGalleryApplier(options.loader.ModelPath)
applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))
app.Get("/models/jobs/:uuid", getOpStatus(applier))
2023-04-20 16:33:02 +00:00
// openAI compatible API endpoint
// chat
2023-05-21 12:38:25 +00:00
app.Post("/v1/chat/completions", chatEndpoint(cm, options))
app.Post("/chat/completions", chatEndpoint(cm, options))
2023-04-20 16:33:02 +00:00
// edit
2023-05-21 12:38:25 +00:00
app.Post("/v1/edits", editEndpoint(cm, options))
app.Post("/edits", editEndpoint(cm, options))
2023-04-29 07:22:09 +00:00
// completion
2023-05-21 12:38:25 +00:00
app.Post("/v1/completions", completionEndpoint(cm, options))
app.Post("/completions", completionEndpoint(cm, options))
2023-04-20 16:33:02 +00:00
// embeddings
2023-05-21 12:38:25 +00:00
app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
app.Post("/embeddings", embeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
// audio
2023-05-21 12:38:25 +00:00
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
2023-05-09 09:43:50 +00:00
// images
2023-05-21 12:38:25 +00:00
app.Post("/v1/images/generations", imageEndpoint(cm, options))
2023-05-21 12:38:25 +00:00
if options.imageDir != "" {
app.Static("/generated-images", options.imageDir)
}
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}
// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)
// models
2023-05-21 12:38:25 +00:00
app.Get("/v1/models", listModels(options.loader, cm))
app.Get("/models", listModels(options.loader, cm))
2023-04-20 16:33:02 +00:00
return app, nil
2023-04-11 21:43:43 +00:00
}