LocalAI/pkg/gallery/models.go
Ettore Di Giacinto 78f3c3da48
refactor: consolidate usage of GetURI (#674)
Signed-off-by: mudler <mudler@localai.io>
2023-06-26 12:25:38 +02:00

318 lines
8.6 KiB
Go

package gallery
import (
"crypto/sha256"
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/imdario/mergo"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v2"
)
/*
description: |
foo
license: ""
urls:
-
-
name: "bar"
config_file: |
# Note, name will be injected. or generated by the alias wanted by the user
threads: 14
files:
- filename: ""
sha: ""
uri: ""
prompt_templates:
- name: ""
content: ""
*/
// Config is the model configuration which contains all the model details
// This configuration is read from the gallery endpoint and is used to download and install the model
type Config struct {
Description string `yaml:"description"`
License string `yaml:"license"`
URLs []string `yaml:"urls"`
Name string `yaml:"name"`
ConfigFile string `yaml:"config_file"`
Files []File `yaml:"files"`
PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
}
type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
}
type PromptTemplate struct {
Name string `yaml:"name"`
Content string `yaml:"content"`
}
func GetGalleryConfigFromURL(url string) (Config, error) {
var config Config
err := utils.GetURI(url, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
return config, err
}
return config, nil
}
func ReadConfigFile(filePath string) (*Config, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read YAML file: %v", err)
}
// Unmarshal YAML data into a Config struct
var config Config
err = yaml.Unmarshal(yamlFile, &config)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
}
return &config, nil
}
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0755)
if err != nil {
return fmt.Errorf("failed to create base path: %v", err)
}
if len(configOverrides) > 0 {
log.Debug().Msgf("Config overrides %+v", configOverrides)
}
// Download files and verify their SHA
for _, file := range config.Files {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(basePath, file.Filename)
// Check if the file already exists
_, err := os.Stat(filePath)
if err == nil {
// File exists, check SHA
if file.SHA256 != "" {
// Verify SHA
calculatedSHA, err := calculateSHA(filePath)
if err != nil {
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
}
if calculatedSHA == file.SHA256 {
// SHA matches, skip downloading
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
continue
}
// SHA doesn't match, delete the file and download again
err = os.Remove(filePath)
if err != nil {
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
}
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
} else {
// SHA is missing, skip downloading
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
continue
}
} else if !os.IsNotExist(err) {
// Error occurred while checking file existence
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
}
log.Debug().Msgf("Downloading %q", file.URI)
// Download file
resp, err := http.Get(file.URI)
if err != nil {
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
}
defer resp.Body.Close()
// Create parent directory
err = os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil {
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
}
// Create and write file content
outFile, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
}
defer outFile.Close()
progress := &progressWriter{
fileName: file.Filename,
total: resp.ContentLength,
hash: sha256.New(),
downloadStatus: downloadStatus,
}
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
if err != nil {
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}
if file.SHA256 != "" {
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 {
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
}
} else {
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
}
log.Debug().Msgf("File %q downloaded and verified", file.Filename)
if utils.IsArchive(filePath) {
log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath)
if err := utils.ExtractArchive(filePath, basePath); err != nil {
log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error())
return err
}
}
}
// Write prompt template contents to separate files
for _, template := range config.PromptTemplates {
if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(basePath, template.Name+".tmpl")
// Create parent directory
err := os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil {
return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err)
}
// Create and write file content
err = os.WriteFile(filePath, []byte(template.Content), 0644)
if err != nil {
return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err)
}
log.Debug().Msgf("Prompt template %q written", template.Name)
}
name := config.Name
if nameOverride != "" {
name = nameOverride
}
if err := utils.VerifyPath(name+".yaml", basePath); err != nil {
return err
}
// write config file
if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
configFilePath := filepath.Join(basePath, name+".yaml")
// Read and update config file as map[string]interface{}
configMap := make(map[string]interface{})
err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap)
if err != nil {
return fmt.Errorf("failed to unmarshal config YAML: %v", err)
}
configMap["name"] = name
if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
return err
}
// Write updated config file
updatedConfigYAML, err := yaml.Marshal(configMap)
if err != nil {
return fmt.Errorf("failed to marshal updated config YAML: %v", err)
}
err = os.WriteFile(configFilePath, updatedConfigYAML, 0644)
if err != nil {
return fmt.Errorf("failed to write updated config file: %v", err)
}
log.Debug().Msgf("Written config file %s", configFilePath)
}
return nil
}
type progressWriter struct {
fileName string
total int64
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.hash.Write(p)
pw.written += int64(n)
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
} else {
pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0)
}
return
}
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return strconv.FormatInt(bytes, 10) + " B"
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
func calculateSHA(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}