From 563c5b7ea04b03b87af0d05c0ac4e507df9a9cd3 Mon Sep 17 00:00:00 2001 From: lunamidori5 <118759930+lunamidori5@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:06:45 -0800 Subject: [PATCH] Added Check API KEYs file to API.go (#1381) Added API KEYs file Signed-off-by: lunamidori5 <118759930+lunamidori5@users.noreply.github.com> --- api/api.go | 54 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/api/api.go b/api/api.go index 9a097838..4970623e 100644 --- a/api/api.go +++ b/api/api.go @@ -3,6 +3,8 @@ package api import ( "errors" "fmt" + "encoding/json" + "io/ioutil" "strings" config "github.com/go-skynet/LocalAI/api/config" @@ -144,30 +146,48 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { - if len(options.ApiKeys) > 0 { - authHeader := c.Get("Authorization") - if authHeader == "" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) - } - authHeaderParts := strings.Split(authHeader, " ") - if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + if len(options.ApiKeys) == 0 { + return c.Next() + } + + // Check for api_keys.json file + fileContent, err := ioutil.ReadFile("api_keys.json") + if err == nil { + // Parse JSON content from the file + var fileKeys []string + err := json.Unmarshal(fileContent, &fileKeys) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) } - apiKey := authHeaderParts[1] - validApiKey := false - for _, key := range options.ApiKeys { - if apiKey == key { - validApiKey = true - } - } - if !validApiKey { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + // Add file keys to options.ApiKeys + options.ApiKeys = append(options.ApiKeys, fileKeys...) + } + + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) + } + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } + + apiKey := authHeaderParts[1] + validApiKey := false + for _, key := range options.ApiKeys { + if apiKey == key { + validApiKey = true } } + if !validApiKey { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + } + return c.Next() } + if options.CORS { var c func(ctx *fiber.Ctx) error if options.CORSAllowOrigins == "" {