LocalAI/pkg/schema/config.go
Dave ab7b4d5ee9
[Refactor]: Core/API Split (#1506)
Refactors api folder to core, creates firm split between backend code and api frontend.
2024-01-05 15:34:56 +01:00

401 lines
10 KiB
Go

package schema
import (
"encoding/json"
"fmt"
"os"
"github.com/go-skynet/LocalAI/pkg/utils"
"gopkg.in/yaml.v3"
)
type Config struct {
PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
FunctionsConfig Functions `yaml:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline"`
// AutoGPTQ specifics
AutoGPTQ AutoGPTQ `yaml:"autogptq"`
// Diffusers
Diffusers Diffusers `yaml:"diffusers"`
Step int `yaml:"step"`
// GRPC Options
GRPC GRPC `yaml:"grpc"`
// Vall-e-x
VallE VallE `yaml:"vall-e"`
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda"`
DownloadFiles []File `yaml:"download_files"`
}
type File struct {
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI string `yaml:"uri" json:"uri"`
}
type VallE struct {
AudioPath string `yaml:"audio_path"`
}
type FeatureFlag map[string]*bool
func (ff FeatureFlag) Enabled(s string) bool {
v, exist := ff[s]
return exist && v != nil && *v
}
type GRPC struct {
Attempts int `yaml:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time"`
}
type Diffusers struct {
CUDA bool `yaml:"cuda"`
PipelineType string `yaml:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net"`
}
type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt"`
TensorSplit string `yaml:"tensor_split"`
MainGPU string `yaml:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
MMap bool `yaml:"mmap"`
MMlock bool `yaml:"mmlock"`
LowVRAM bool `yaml:"low_vram"`
Grammar string `yaml:"grammar"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix"`
ContextSize int `yaml:"context_size"`
NUMA bool `yaml:"numa"`
LoraAdapter string `yaml:"lora_adapter"`
LoraBase string `yaml:"lora_base"`
LoraScale float32 `yaml:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq"`
DraftModel string `yaml:"draft_model"`
NDraft int32 `yaml:"n_draft"`
Quantization string `yaml:"quantization"`
MMProj string `yaml:"mmproj"`
RopeScaling string `yaml:"rope_scaling"`
YarnExtFactor float32 `yaml:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
}
type AutoGPTQ struct {
ModelBaseName string `yaml:"model_base_name"`
Device string `yaml:"device"`
Triton bool `yaml:"triton"`
UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
}
type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
}
type TemplateConfig struct {
Chat string `yaml:"chat"`
ChatMessage string `yaml:"chat_message"`
Completion string `yaml:"completion"`
Edit string `yaml:"edit"`
Functions string `yaml:"function"`
}
func (c *Config) SetFunctionCallString(s string) {
c.functionCallString = s
}
func (c *Config) SetFunctionCallNameString(s string) {
c.functionCallNameString = s
}
func (c *Config) ShouldUseFunctions() bool {
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}
func (c *Config) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}
func (c *Config) FunctionToCall() string {
return c.functionCallNameString
}
func defaultPredictOptions(modelFile string) PredictionOptions {
return PredictionOptions{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
}
}
func DefaultConfig(modelFile string) *Config {
return &Config{
PredictionOptions: defaultPredictOptions(modelFile),
}
}
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return *c, nil
}
func ReadSingleConfigFile(file string) (*Config, error) {
c := &Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return c, nil
}
func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.ModelBaseName != "" {
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.UseFastTokenizer {
config.UseFastTokenizer = input.UseFastTokenizer
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
if input.Mirostat != 0 {
config.LLMConfig.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.LLMConfig.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.LLMConfig.MirostatTAU = input.MirostatTAU
}
if input.TypicalP != 0 {
config.TypicalP = input.TypicalP
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
// Decode each request's message content
index := 0
for i, m := range input.Messages {
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
input.Messages[i].StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := utils.GetBase64Image(pp.ImageURL.URL)
if err == nil {
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
// TODO: check that this was merged correctly? I _think_ it is?
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []interface{}:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
}