package api import ( "encoding/json" "errors" "fmt" "os" "strings" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/localai" "github.com/go-skynet/LocalAI/api/openai" "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/schema" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/assets" "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/startup" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/recover" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { options := options.NewOptions(opts...) zerolog.SetGlobalLevel(zerolog.InfoLevel) if options.Debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) } log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) startup.PreloadModelsConfigurations(options.Loader.ModelPath, options.ModelsURL...) cl := config.NewConfigLoader() if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { log.Error().Msgf("error loading config files: %s", err.Error()) } if options.ConfigFile != "" { if err := cl.LoadConfigFile(options.ConfigFile); err != nil { log.Error().Msgf("error loading config file: %s", err.Error()) } } if err := cl.Preload(options.Loader.ModelPath); err != nil { log.Error().Msgf("error downloading models: %s", err.Error()) } if options.PreloadJSONModels != "" { if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { return nil, nil, err } } if options.PreloadModelsFromPath != "" { if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { return nil, nil, err } } if options.Debug { for _, v := range cl.ListConfigs() { cfg, _ := cl.GetConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } if options.AssetsDestination != "" { // Extract files from the embedded FS err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) if err != nil { log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) } } // turn off any process that was started by GRPC if the context is canceled go func() { <-options.Context.Done() log.Debug().Msgf("Context canceled, shutting down") options.Loader.StopAllGRPC() }() if options.WatchDog { wd := model.NewWatchDog( options.Loader, options.WatchDogBusyTimeout, options.WatchDogIdleTimeout, options.WatchDogBusy, options.WatchDogIdle) options.Loader.SetWatchDog(wd) go wd.Run() go func() { <-options.Context.Done() log.Debug().Msgf("Context canceled, shutting down") wd.Shutdown() }() } return options, cl, nil } func App(opts ...options.AppOption) (*fiber.App, error) { options, cl, err := Startup(opts...) if err != nil { return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) } // Return errors as JSON responses app := fiber.New(fiber.Config{ BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB DisableStartupMessage: options.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 code := fiber.StatusInternalServerError // Retrieve the custom status code if it's a *fiber.Error var e *fiber.Error if errors.As(err, &e) { code = e.Code } // Send custom error page return ctx.Status(code).JSON( schema.ErrorResponse{ Error: &schema.APIError{Message: err.Error(), Code: code}, }, ) }, }) if options.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) } // Default middleware config app.Use(recover.New()) if options.Metrics != nil { app.Use(metrics.APIMiddleware(options.Metrics)) } // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { if len(options.ApiKeys) == 0 { return c.Next() } // Check for api_keys.json file fileContent, err := os.ReadFile("api_keys.json") if err == nil { // Parse JSON content from the file var fileKeys []string err := json.Unmarshal(fileContent, &fileKeys) if err != nil { return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) } // Add file keys to options.ApiKeys options.ApiKeys = append(options.ApiKeys, fileKeys...) } if len(options.ApiKeys) == 0 { return c.Next() } authHeader := c.Get("Authorization") if authHeader == "" { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) } authHeaderParts := strings.Split(authHeader, " ") if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) } apiKey := authHeaderParts[1] for _, key := range options.ApiKeys { if apiKey == key { return c.Next() } } return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) } if options.CORS { var c func(ctx *fiber.Ctx) error if options.CORSAllowOrigins == "" { c = cors.New() } else { c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) } app.Use(c) } // LocalAI API endpoints galleryService := localai.NewGalleryService(options.Loader.ModelPath) galleryService.Start(options.Context, cl) app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { Version string `json:"version"` }{Version: internal.PrintableVersion()}) }) modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) // openAI compatible API endpoint // chat app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) // edit app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) app.Post("/edits", auth, openai.EditEndpoint(cl, options)) // completion app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) // embeddings app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) // audio app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) // images app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) if options.ImageDir != "" { app.Static("/generated-images", options.ImageDir) } if options.AudioDir != "" { app.Static("/generated-audio", options.AudioDir) } ok := func(c *fiber.Ctx) error { return c.SendStatus(200) } // Kubernetes health checks app.Get("/healthz", ok) app.Get("/readyz", ok) // Experimental Backend Statistics Module backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) // models app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) app.Get("/metrics", metrics.MetricsHandler()) return app, nil }