feat: add /models/apply endpoint to prepare models (#286)

This commit is contained in:
Ettore Di Giacinto 2023-05-18 15:59:03 +02:00 committed by GitHub
parent 5617e50ebc
commit cc9aa9eb3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 556 additions and 33 deletions

View File

@ -211,11 +211,11 @@ test-models/testmodel:
wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
wget https://huggingface.co/imxcstar/rwkv-4-raven-ggml/resolve/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096-16_Q4_2.bin -O test-models/rwkv
wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json
cp tests/fixtures/* test-models
cp tests/models_fixtures/* test-models
test: prepare test-models/testmodel
cp tests/fixtures/* test-models
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
cp tests/models_fixtures/* test-models
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api ./pkg
## Help:
help: ## Show this help.

View File

@ -11,11 +11,12 @@
**LocalAI** is a drop-in replacement REST API compatible with OpenAI API specifications for local inferencing. It allows to run models locally or on-prem with consumer grade hardware, supporting multiple models families compatible with the `ggml` format. For a list of the supported model families, see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table).
- OpenAI drop-in alternative REST API
- Local, OpenAI drop-in alternative REST API. You own your data.
- Supports multiple models, Audio transcription, Text generation with GPTs, Image generation with stable diffusion (experimental)
- Once loaded the first time, it keep models loaded in memory for faster inference
- Support for prompt templates
- Doesn't shell-out, but uses C++ bindings for a faster inference and better performance.
- NO GPU required. NO Internet access is required either. Optional, GPU Acceleration is available in `llama.cpp`-compatible LLMs. [See building instructions](https://github.com/go-skynet/LocalAI#cublas).
LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
@ -434,7 +435,7 @@ local-ai --models-path <model_path> [--address <address>] [--threads <num_thread
| debug | DEBUG | false | Enable debug mode. |
| config-file | CONFIG_FILE | empty | Path to a LocalAI config file. |
| upload_limit | UPLOAD_LIMIT | 5MB | Upload limit for whisper. |
| image-dir | CONFIG_FILE | empty | Image directory to store and serve processed images. |
| image-path | IMAGE_PATH | empty | Image directory to store and serve processed images. |
</details>
@ -567,6 +568,8 @@ Note: CuBLAS support is experimental, and has not been tested on real HW. please
make BUILD_TYPE=cublas build
```
More informations available in the upstream PR: https://github.com/ggerganov/llama.cpp/pull/1412
</details>
### Windows compatibility

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
model "github.com/go-skynet/LocalAI/pkg/model"
@ -12,7 +13,7 @@ import (
"github.com/rs/zerolog/log"
)
func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
@ -48,7 +49,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
}))
}
cm := make(ConfigMerger)
cm := NewConfigMerger()
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
@ -60,39 +61,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
}
if debug {
for k, v := range cm {
log.Debug().Msgf("Model: %s (config: %+v)", k, v)
for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}
// Default middleware config
app.Use(recover.New())
app.Use(cors.New())
// LocalAI API endpoints
applier := newGalleryApplier(loader.ModelPath)
applier.start(c, cm)
app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
app.Get("/models/jobs/:uid", getOpStatus(applier))
// openAI compatible API endpoint
// chat
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
// edit
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
// completion
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
// embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
// /v1/engines/{engine_id}/embeddings
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
// audio
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
// images
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
if imageDir != "" {
app.Static("/generated-images", imageDir)
}
// models
app.Get("/v1/models", listModels(loader, cm))
app.Get("/models", listModels(loader, cm))

View File

@ -22,10 +22,14 @@ var _ = Describe("API test", func() {
var modelLoader *model.ModelLoader
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
var cancel context.CancelFunc
Context("API query", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App("", modelLoader, 15, 1, 512, false, true, true, "")
c, cancel = context.WithCancel(context.Background())
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@ -42,6 +46,7 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("returns the models list", func() {
@ -140,7 +145,9 @@ var _ = Describe("API test", func() {
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
c, cancel = context.WithCancel(context.Background())
app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
@ -155,10 +162,10 @@ var _ = Describe("API test", func() {
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
})
It("can generate chat completions from config file", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(12))

View File

@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
@ -43,8 +44,16 @@ type TemplateConfig struct {
Edit string `yaml:"edit"`
}
type ConfigMerger map[string]Config
type ConfigMerger struct {
configs map[string]Config
sync.Mutex
}
func NewConfigMerger() *ConfigMerger {
return &ConfigMerger{
configs: make(map[string]Config),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) {
}
func (cm ConfigMerger) LoadConfigFile(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm[cc.Name] = *cc
cm.configs[cc.Name] = *cc
}
return nil
}
func (cm ConfigMerger) LoadConfig(file string) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm[c.Name] = *c
cm.configs[c.Name] = *c
return nil
}
func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
cm.Lock()
defer cm.Unlock()
v, exists := cm.configs[m]
return v, exists
}
func (cm ConfigMerger) ListConfigs() []string {
cm.Lock()
defer cm.Unlock()
var res []string
for k := range cm.configs {
res = append(res, k)
}
return res
}
func (cm ConfigMerger) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
files, err := ioutil.ReadDir(path)
if err != nil {
return err
@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error {
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
if err == nil {
cm[c.Name] = *c
cm.configs[c.Name] = *c
}
}
@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
return modelFile, input, nil
}
func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader
}
var config *Config
cfg, exists := cm[modelFile]
cfg, exists := cm.GetConfig(modelFile)
if !exists {
config = &Config{
OpenAIRequest: defaultRequest(modelFile),

146
api/gallery.go Normal file
View File

@ -0,0 +1,146 @@
package api
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"sync"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"gopkg.in/yaml.v3"
)
type galleryOp struct {
req ApplyGalleryModelRequest
id string
}
type galleryOpStatus struct {
Error error `json:"error"`
Processed bool `json:"processed"`
Message string `json:"message"`
}
type galleryApplier struct {
modelPath string
sync.Mutex
C chan galleryOp
statuses map[string]*galleryOpStatus
}
func newGalleryApplier(modelPath string) *galleryApplier {
return &galleryApplier{
modelPath: modelPath,
C: make(chan galleryOp),
statuses: make(map[string]*galleryOpStatus),
}
}
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) {
g.Lock()
defer g.Unlock()
g.statuses[s] = op
}
func (g *galleryApplier) getstatus(s string) *galleryOpStatus {
g.Lock()
defer g.Unlock()
return g.statuses[s]
}
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) {
go func() {
for {
select {
case <-c.Done():
return
case op := <-g.C:
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
updateError := func(e error) {
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
}
// Send a GET request to the URL
response, err := http.Get(op.req.URL)
if err != nil {
updateError(err)
continue
}
defer response.Body.Close()
// Read the response body
body, err := ioutil.ReadAll(response.Body)
if err != nil {
updateError(err)
continue
}
// Unmarshal YAML data into a Config struct
var config gallery.Config
err = yaml.Unmarshal(body, &config)
if err != nil {
updateError(fmt.Errorf("failed to unmarshal YAML: %v", err))
continue
}
if err := gallery.Apply(g.modelPath, op.req.Name, &config); err != nil {
updateError(err)
continue
}
// Reload models
if err := cm.LoadConfigs(g.modelPath); err != nil {
updateError(err)
continue
}
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
}
}
}()
}
// endpoints
type ApplyGalleryModelRequest struct {
URL string `json:"url"`
Name string `json:"name"`
}
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := g.getstatus(c.Params("uid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(ApplyGalleryModelRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
g <- galleryOp{
req: *input,
id: uuid.String(),
}
return c.JSON(struct {
ID string `json:"uid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}

View File

@ -142,7 +142,7 @@ func defaultRequest(modelFile string) OpenAIRequest {
}
// https://platform.openai.com/docs/api-reference/completions
func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
@ -199,7 +199,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
// https://platform.openai.com/docs/api-reference/embeddings
func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
if err != nil {
@ -256,7 +256,7 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
@ -378,7 +378,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
}
}
func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true)
if err != nil {
@ -449,7 +449,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
*
*/
func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
if err != nil {
@ -574,7 +574,7 @@ func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, image
}
// https://platform.openai.com/docs/api-reference/audio/create
func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false)
if err != nil {
@ -641,7 +641,7 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
@ -655,7 +655,7 @@ func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx)
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
}
for k := range cm {
for _, k := range cm.ListConfigs() {
if _, exists := mm[k]; !exists {
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
"path/filepath"
@ -57,9 +58,9 @@ func main() {
Value: ":8080",
},
&cli.StringFlag{
Name: "image-dir",
Name: "image-path",
DefaultText: "Image directory",
EnvVars: []string{"IMAGE_DIR"},
EnvVars: []string{"IMAGE_PATH"},
Value: "",
},
&cli.IntFlag{
@ -93,7 +94,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Copyright: "go-skynet authors",
Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-dir")).Listen(ctx.String("address"))
return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address"))
},
}

View File

@ -0,0 +1,13 @@
package gallery_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestGallery(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Gallery test suite")
}

237
pkg/gallery/models.go Normal file
View File

@ -0,0 +1,237 @@
package gallery
import (
"crypto/sha256"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
)
/*
description: |
foo
license: ""
urls:
-
-
name: "bar"
config_file: |
# Note, name will be injected. or generated by the alias wanted by the user
threads: 14
files:
- filename: ""
sha: ""
uri: ""
prompt_templates:
- name: ""
content: ""
*/
type Config struct {
Description string `yaml:"description"`
License string `yaml:"license"`
URLs []string `yaml:"urls"`
Name string `yaml:"name"`
ConfigFile string `yaml:"config_file"`
Files []File `yaml:"files"`
PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
}
type File struct {
Filename string `yaml:"filename"`
SHA256 string `yaml:"sha256"`
URI string `yaml:"uri"`
}
type PromptTemplate struct {
Name string `yaml:"name"`
Content string `yaml:"content"`
}
func ReadConfigFile(filePath string) (*Config, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read YAML file: %v", err)
}
// Unmarshal YAML data into a Config struct
var config Config
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
}
return &config, nil
}
func Apply(basePath, nameOverride string, config *Config) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755)
if err != nil {
return fmt.Errorf("failed to create base path: %v", err)
}
// Download files and verify their SHA
for _, file := range config.Files {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
// Create file path
filePath := filepath.Join(basePath, file.Filename)
// Check if the file already exists
_, err := os.Stat(filePath)
if err == nil {
// File exists, check SHA
if file.SHA256 != "" {
// Verify SHA
calculatedSHA, err := calculateSHA(filePath)
if err != nil {
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
}
if calculatedSHA == file.SHA256 {
// SHA matches, skip downloading
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
continue
}
// SHA doesn't match, delete the file and download again
err = os.Remove(filePath)
if err != nil {
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
}
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
} else {
// SHA is missing, skip downloading
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
continue
}
} else if !os.IsNotExist(err) {
// Error occurred while checking file existence
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
}
log.Debug().Msgf("Downloading %q", file.URI)
// Download file
resp, err := http.Get(file.URI)
if err != nil {
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
}
defer resp.Body.Close()
// Create parent directory
err = os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil {
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
}
// Create and write file content
outFile, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
}
defer outFile.Close()
if file.SHA256 != "" {
log.Debug().Msgf("Download and verifying %q", file.Filename)
// Write file content and calculate SHA
hash := sha256.New()
_, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body)
if err != nil {
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil))
if calculatedSHA != file.SHA256 {
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
}
} else {
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
_, err = io.Copy(outFile, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}
}
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
}
// Write prompt template contents to separate files
for _, template := range config.PromptTemplates {
// Create file path
filePath := filepath.Join(basePath, template.Name+".tmpl")
// Create parent directory
err := os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil {
return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
}
// Create and write file content
err = os.WriteFile(filePath, []byte(template.Content), 0644)
if err != nil {
return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
}
log.Debug().Msgf("Prompt template %q written", template.Name)
}
name := config.Name
if nameOverride != "" {
name = nameOverride
}
configFilePath := filepath.Join(basePath, name+".yaml")
// Read and update config file as map[string]interface{}
configMap := make(map[string]interface{})
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
if err != nil {
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
}
configMap["name"] = name
// Write updated config file
updatedConfigYAML, err := yaml.Marshal(configMap)
if err != nil {
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
}
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
if err != nil {
return fmt.Errorf("failed to write updated config file: %v", err)
}
log.Debug().Msgf("Written config file %s", configFilePath)
return nil
}
func calculateSHA(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

View File

@ -0,0 +1,30 @@
package gallery_test
import (
"os"
"path/filepath"
. "github.com/go-skynet/LocalAI/pkg/gallery"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Model test", func() {
Context("Downloading", func() {
It("applies model correctly", func() {
tempdir, err := os.MkdirTemp("", "test")
Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(tempdir)
c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
Expect(err).ToNot(HaveOccurred())
err = Apply(tempdir, "", c)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
_, err = os.Stat(filepath.Join(tempdir, f))
Expect(err).ToNot(HaveOccurred())
}
})
})
})

View File

@ -164,11 +164,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
}
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) {
log.Debug().Msgf("Loading models greedly")
log.Debug().Msgf("Loading model '%s' greedly", modelFile)
ml.mu.Lock()
m, exists := ml.models[modelFile]
if exists {
log.Debug().Msgf("Model '%s' already loaded", modelFile)
ml.mu.Unlock()
return m, nil
}

40
tests/fixtures/gallery_simple.yaml vendored Normal file
View File

@ -0,0 +1,40 @@
name: "cerebras"
description: |
cerebras
license: "Apache 2.0"
config_file: |
parameters:
model: cerebras
top_k: 80
temperature: 0.2
top_p: 0.7
context_size: 1024
stopwords:
- "HUMAN:"
- "GPT:"
roles:
user: ""
system: ""
template:
completion: "cerebras-completion"
chat: cerebras-chat
files:
- filename: "cerebras"
sha256: "c947051ae4dba9530ca55d923a7a484acd65664c8633462c8ccd4bb7848f2c65"
uri: "https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerebras-111m-q4_2.bin"
prompt_templates:
- name: "cerebras-completion"
content: |
Complete the prompt
### Prompt:
{{.Input}}
### Response:
- name: "cerebras-chat"
content: |
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt:
{{.Input}}
### Response: