package schema

import (
	"encoding/json"
	"fmt"
	"os"

	"github.com/go-skynet/LocalAI/pkg/utils"
	"gopkg.in/yaml.v3"
)

// Config is the configuration of a single model as decoded from YAML.
// Fields tagged `yaml:"-"` are never read from the file; they are filled in
// per request by UpdateConfigFromOpenAIRequest and the SetFunctionCall*
// setters.
type Config struct {
	PredictionOptions `yaml:"parameters"`
	Name              string `yaml:"name"`

	F16            bool              `yaml:"f16"`
	Threads        int               `yaml:"threads"`
	Debug          bool              `yaml:"debug"`
	Roles          map[string]string `yaml:"roles"`
	Embeddings     bool              `yaml:"embeddings"`
	Backend        string            `yaml:"backend"`
	TemplateConfig TemplateConfig    `yaml:"template"`

	PromptStrings, InputStrings                []string `yaml:"-"`
	InputToken                                 [][]int  `yaml:"-"`
	functionCallString, functionCallNameString string   `yaml:"-"`

	FunctionsConfig Functions `yaml:"function"`

	// Feature Flag registry. We move fast, and features may break on a per
	// model/backend basis. Registry for (usually temporary) flags that
	// indicate aborting something early.
	FeatureFlag FeatureFlag `yaml:"feature_flags"`

	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline"`

	// AutoGPTQ specifics
	AutoGPTQ AutoGPTQ `yaml:"autogptq"`

	// Diffusers
	Diffusers Diffusers `yaml:"diffusers"`
	Step      int       `yaml:"step"`

	// GRPC Options
	GRPC GRPC `yaml:"grpc"`

	// Vall-e-x
	VallE VallE `yaml:"vall-e"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
	CUDA bool `yaml:"cuda"`

	DownloadFiles []File `yaml:"download_files"`
}

// File is one entry of Config.DownloadFiles.
type File struct {
	Filename string `yaml:"filename" json:"filename"`
	SHA256   string `yaml:"sha256" json:"sha256"`
	URI      string `yaml:"uri" json:"uri"`
}

// VallE holds the Vall-e-x specific options.
type VallE struct {
	AudioPath string `yaml:"audio_path"`
}

// FeatureFlag maps a flag name to an optional boolean value.
type FeatureFlag map[string]*bool

// Enabled reports whether flag s is present, non-nil and set to true.
// It is safe to call on a nil FeatureFlag.
func (ff FeatureFlag) Enabled(s string) bool {
	v, exist := ff[s]
	return exist && v != nil && *v
}

// GRPC holds the gRPC connection options.
type GRPC struct {
	Attempts          int `yaml:"attempts"`
	AttemptsSleepTime int `yaml:"attempts_sleep_time"`
}

// Diffusers holds the diffusers specific options.
type Diffusers struct {
	CUDA             bool    `yaml:"cuda"`
	PipelineType     string  `yaml:"pipeline_type"`
	SchedulerType    string  `yaml:"scheduler_type"`
	EnableParameters string  `yaml:"enable_parameters"` // A list of comma separated parameters to specify
	CFGScale         float32 `yaml:"cfg_scale"`         // Classifier-Free Guidance Scale
	IMG2IMG          bool    `yaml:"img2img"`           // Image to Image Diffuser
	ClipSkip         int     `yaml:"clip_skip"`         // Skip every N frames
	ClipModel        string  `yaml:"clip_model"`        // Clip model to use
	ClipSubFolder    string  `yaml:"clip_subfolder"`    // Subfolder to use for clip model
	ControlNet       string  `yaml:"control_net"`
}

// LLMConfig holds the LLM specific options. It is embedded inline in Config,
// so its keys live at the top level of the YAML document.
type LLMConfig struct {
	SystemPrompt    string   `yaml:"system_prompt"`
	TensorSplit     string   `yaml:"tensor_split"`
	MainGPU         string   `yaml:"main_gpu"`
	RMSNormEps      float32  `yaml:"rms_norm_eps"`
	NGQA            int32    `yaml:"ngqa"`
	PromptCachePath string   `yaml:"prompt_cache_path"`
	PromptCacheAll  bool     `yaml:"prompt_cache_all"`
	PromptCacheRO   bool     `yaml:"prompt_cache_ro"`
	MirostatETA     float64  `yaml:"mirostat_eta"`
	MirostatTAU     float64  `yaml:"mirostat_tau"`
	Mirostat        int      `yaml:"mirostat"`
	NGPULayers      int      `yaml:"gpu_layers"`
	MMap            bool     `yaml:"mmap"`
	MMlock          bool     `yaml:"mmlock"`
	LowVRAM         bool     `yaml:"low_vram"`
	Grammar         string   `yaml:"grammar"`
	StopWords       []string `yaml:"stopwords"`
	Cutstrings      []string `yaml:"cutstrings"`
	TrimSpace       []string `yaml:"trimspace"`
	TrimSuffix      []string `yaml:"trimsuffix"`

	ContextSize    int     `yaml:"context_size"`
	NUMA           bool    `yaml:"numa"`
	LoraAdapter    string  `yaml:"lora_adapter"`
	LoraBase       string  `yaml:"lora_base"`
	LoraScale      float32 `yaml:"lora_scale"`
	NoMulMatQ      bool    `yaml:"no_mulmatq"`
	DraftModel     string  `yaml:"draft_model"`
	NDraft         int32   `yaml:"n_draft"`
	Quantization   string  `yaml:"quantization"`
	MMProj         string  `yaml:"mmproj"`
	RopeScaling    string  `yaml:"rope_scaling"`
	YarnExtFactor  float32 `yaml:"yarn_ext_factor"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
	YarnBetaFast   float32 `yaml:"yarn_beta_fast"`
	YarnBetaSlow   float32 `yaml:"yarn_beta_slow"`
}

// AutoGPTQ holds the AutoGPTQ specific options.
type AutoGPTQ struct {
	ModelBaseName    string `yaml:"model_base_name"`
	Device           string `yaml:"device"`
	Triton           bool   `yaml:"triton"`
	UseFastTokenizer bool   `yaml:"use_fast_tokenizer"`
}

// Functions holds the function-calling options (YAML key "function").
type Functions struct {
	DisableNoAction         bool   `yaml:"disable_no_action"`
	NoActionFunctionName    string `yaml:"no_action_function_name"`
	NoActionDescriptionName string `yaml:"no_action_description_name"`
}

// TemplateConfig holds one template setting per kind of request.
type TemplateConfig struct {
	Chat        string `yaml:"chat"`
	ChatMessage string `yaml:"chat_message"`
	Completion  string `yaml:"completion"`
	Edit        string `yaml:"edit"`
	Functions   string `yaml:"function"`
}

// SetFunctionCallString records the string form of a request's function_call
// value.
func (c *Config) SetFunctionCallString(s string) {
	c.functionCallString = s
}

// SetFunctionCallNameString records the name of the specific function the
// request asked to call.
func (c *Config) SetFunctionCallNameString(s string) {
	c.functionCallNameString = s
}

// ShouldUseFunctions reports whether functions should be used: it is false
// only when the function-call string is exactly "none" and no specific
// function name has been set.
func (c *Config) ShouldUseFunctions() bool {
	// An empty string already satisfies != "none", so no separate check for
	// it is needed.
	return c.functionCallString != "none" || c.ShouldCallSpecificFunction()
}

// ShouldCallSpecificFunction reports whether a specific function name has
// been set.
func (c *Config) ShouldCallSpecificFunction() bool {
	return len(c.functionCallNameString) > 0
}

// FunctionToCall returns the specific function name, or "" if none was set.
func (c *Config) FunctionToCall() string {
	return c.functionCallNameString
}

// defaultPredictOptions returns the default prediction parameters for
// modelFile.
func defaultPredictOptions(modelFile string) PredictionOptions {
	return PredictionOptions{
		TopP:        0.7,
		TopK:        80,
		Maxtokens:   512,
		Temperature: 0.9,
		Model:       modelFile,
	}
}

// DefaultConfig returns a Config holding only the default prediction
// parameters for modelFile.
func DefaultConfig(modelFile string) *Config {
	return &Config{
		PredictionOptions: defaultPredictOptions(modelFile),
	}
}

// ReadConfigFile reads file and decodes it as a YAML list of configurations.
func ReadConfigFile(file string) ([]*Config, error) {
	// Start from a non-nil empty slice so an empty document still yields an
	// empty (not nil) result.
	configs := []*Config{}
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}
	if err := yaml.Unmarshal(f, &configs); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	return configs, nil
}

// ReadSingleConfigFile reads file and decodes it as a single YAML
// configuration.
func ReadSingleConfigFile(file string) (*Config, error) {
	c := &Config{}
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}
	if err := yaml.Unmarshal(f, c); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	return c, nil
}

// UpdateConfigFromOpenAIRequest overlays the values carried by input onto
// config.
//
// Scalar options are copied only when the request value is non-zero (or true,
// or non-empty), so a request cannot reset an option back to its zero value.
// Stop words, prompts, input strings and input token lists are appended to
// whatever config already holds.
//
// The function also mutates input: for every message it fills StringContent
// and, for image parts that could be encoded, StringImages.
func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != 0 {
		config.TopK = input.TopK
	}
	if input.TopP != 0 {
		config.TopP = input.TopP
	}

	if input.Backend != "" {
		config.Backend = input.Backend
	}

	if input.ClipSkip != 0 {
		config.Diffusers.ClipSkip = input.ClipSkip
	}

	if input.ModelBaseName != "" {
		config.AutoGPTQ.ModelBaseName = input.ModelBaseName
	}

	if input.NegativePromptScale != 0 {
		config.NegativePromptScale = input.NegativePromptScale
	}

	if input.UseFastTokenizer {
		config.UseFastTokenizer = input.UseFastTokenizer
	}

	if input.NegativePrompt != "" {
		config.NegativePrompt = input.NegativePrompt
	}

	if input.RopeFreqBase != 0 {
		config.RopeFreqBase = input.RopeFreqBase
	}

	if input.RopeFreqScale != 0 {
		config.RopeFreqScale = input.RopeFreqScale
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != 0 {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != 0 {
		config.Maxtokens = input.Maxtokens
	}

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.F16 {
		config.F16 = input.F16
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != 0 {
		config.Seed = input.Seed
	}

	if input.Mirostat != 0 {
		config.LLMConfig.Mirostat = input.Mirostat
	}

	if input.MirostatETA != 0 {
		config.LLMConfig.MirostatETA = input.MirostatETA
	}

	if input.MirostatTAU != 0 {
		config.LLMConfig.MirostatTAU = input.MirostatTAU
	}

	if input.TypicalP != 0 {
		config.TypicalP = input.TypicalP
	}

	// Stop can be either a single string or a list; non-string list entries
	// are ignored.
	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	// Decode each request's message content.
	// index numbers the image placeholders across ALL messages of the
	// request, not per message.
	index := 0
	for i, m := range input.Messages {
		switch content := m.Content.(type) {
		case string:
			input.Messages[i].StringContent = content
		case []interface{}:
			// Round-trip through JSON to turn the generic list into typed
			// content parts.
			dat, err := json.Marshal(content)
			if err != nil {
				fmt.Printf("Failed encoding message content: %v\n", err)
				continue
			}
			parts := []Content{}
			if err := json.Unmarshal(dat, &parts); err != nil {
				// Best effort: report the problem but still use whatever
				// parts were decoded.
				fmt.Printf("Failed decoding message content: %v\n", err)
			}
			for _, pp := range parts {
				if pp.Type == "text" {
					// NOTE(review): this overwrites any image placeholder
					// already prepended by an earlier image_url part of the
					// same message - confirm whether that is intended.
					input.Messages[i].StringContent = pp.Text
				} else if pp.Type == "image_url" {
					// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
					encoded, err := utils.GetBase64Image(pp.ImageURL.URL)
					if err == nil {
						input.Messages[i].StringImages = append(input.Messages[i].StringImages, encoded) // TODO: make sure that we only return base64 stuff
						// set a placeholder for each image
						input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
						index++
					} else {
						fmt.Printf("Failed encoding image: %v\n", err)
					}
				}
			}
		}
	}

	// TODO: check that this was merged correctly? I _think_ it is?
	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []interface{}:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, i)
			case []interface{}:
				tokens := make([]int, 0, len(i))
				for _, ii := range i {
					// Request data is untrusted: skip anything that is not a
					// JSON number instead of panicking on the assertion.
					f, ok := ii.(float64)
					if !ok {
						continue
					}
					tokens = append(tokens, int(f))
				}
				config.InputToken = append(config.InputToken, tokens)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		// A missing or non-string "name" leaves the function name empty.
		var name string
		if n, exists := fnc["name"]; exists {
			if nn, ok := n.(string); ok {
				name = nn
			}
		}
		config.SetFunctionCallNameString(name)
	}

	// Prompt can be either a single string or a list; non-string list
	// entries are ignored.
	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}
}