feat: Model Gallery Endpoint Refactor / Mutable Galleries Endpoints (#991)

refactor for model gallery endpoints - bundle up resources into a
struct, make galleries mutable with some crud endpoints. This is
groundwork required for making efficient use of the new scraper - while
that PR isn't _quite_ ready yet, the goal is to have more, individually
smaller gallery files. Therefore, rather than requiring a full localai
service restart, these new endpoints have been added to make life
easier.

- Adds endpoints to add, list and remove model galleries at runtime
- Adds these endpoints to the Insomnia config
- Minor fix: loading file urls follows symbolic links now
This commit is contained in:
Dave 2023-09-02 03:00:44 -04:00 committed by GitHub
parent 3d7553317f
commit 005f289632
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 108 additions and 32 deletions

View File

@ -168,10 +168,14 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
}{Version: internal.PrintableVersion()})
})
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cl, galleryService.C, options.Galleries))
app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService))
app.Get("/models/jobs", auth, localai.GetAllStatusEndpoint(galleryService))
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
// openAI compatible API endpoint

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
@ -51,7 +52,6 @@ func NewGalleryService(modelPath string) *galleryApplier {
}
}
// prepareModel applies a
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error {
config, err := gallery.GetGalleryConfigFromURL(req.URL)
@ -184,24 +184,12 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
return processRequests(modelPath, s, cm, galleries, requests)
}
/// Endpoints
/// Endpoint Service
func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := g.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func GetAllStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(g.getAllStatus())
}
type ModelGalleryService struct {
galleries []gallery.Gallery
modelPath string
galleryApplier *galleryApplier
}
type GalleryModel struct {
@ -209,7 +197,31 @@ type GalleryModel struct {
gallery.GalleryModel
}
func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error {
func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService {
return ModelGalleryService{
galleries: galleries,
modelPath: modelPath,
galleryApplier: galleryApplier,
}
}
func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.getStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.getAllStatus())
}
}
func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
@ -221,11 +233,11 @@ func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan
if err != nil {
return err
}
g <- galleryOp{
mgs.galleryApplier.C <- galleryOp{
req: input.GalleryModel,
id: uuid.String(),
galleryName: input.ID,
galleries: galleries,
galleries: mgs.galleries,
}
return c.JSON(struct {
ID string `json:"uuid"`
@ -234,11 +246,11 @@ func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan
}
}
func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error {
func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", galleries)
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
models, err := gallery.AvailableGalleryModels(galleries, basePath)
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
if err != nil {
return err
}
@ -253,3 +265,56 @@ func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string)
return c.Send(dat)
}
}
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s already exists", input.Name)
}
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
log.Debug().Msgf("Adding %+v to gallery list", *input)
mgs.galleries = append(mgs.galleries, *input)
return c.Send(dat)
}
}
func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(gallery.Gallery)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
}) {
return fmt.Errorf("%s is not currently registered", input.Name)
}
mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool {
return gallery.Name == input.Name
})
return c.Send(nil)
}
}

View File

@ -99,6 +99,7 @@ func WithStringGalleries(galls string) AppOption {
return func(o *Option) {
if galls == "" {
log.Debug().Msgf("no galleries to load")
o.Galleries = []gallery.Gallery{}
return
}
var galleries []gallery.Gallery

File diff suppressed because one or more lines are too long

View File

@ -98,8 +98,8 @@ func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath st
}
// List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
// Models galleries are a list of yaml files that are hosted on a remote server (for example github).
// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting.
func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryModel, error) {
var models []*GalleryModel

View File

@ -5,6 +5,7 @@ import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
)
@ -32,8 +33,13 @@ func GetURI(url string, f func(url string, i []byte) error) error {
if strings.HasPrefix(url, "file://") {
rawURL := strings.TrimPrefix(url, "file://")
// checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified.
resolvedFile, err := filepath.EvalSymlinks(rawURL)
if err != nil {
return err
}
// Read the response body
body, err := os.ReadFile(rawURL)
body, err := os.ReadFile(resolvedFile)
if err != nil {
return err
}