2024-01-05 14:34:56 +00:00
package schema
2023-07-14 23:19:43 +00:00
import (
2024-01-05 14:34:56 +00:00
"encoding/json"
2023-07-14 23:19:43 +00:00
"fmt"
"os"
2023-12-18 17:58:44 +00:00
"github.com/go-skynet/LocalAI/pkg/utils"
2023-07-14 23:19:43 +00:00
"gopkg.in/yaml.v3"
)
type Config struct {
PredictionOptions ` yaml:"parameters" `
2023-08-09 06:38:51 +00:00
Name string ` yaml:"name" `
F16 bool ` yaml:"f16" `
Threads int ` yaml:"threads" `
Debug bool ` yaml:"debug" `
Roles map [ string ] string ` yaml:"roles" `
Embeddings bool ` yaml:"embeddings" `
Backend string ` yaml:"backend" `
TemplateConfig TemplateConfig ` yaml:"template" `
PromptStrings , InputStrings [ ] string ` yaml:"-" `
InputToken [ ] [ ] int ` yaml:"-" `
functionCallString , functionCallNameString string ` yaml:"-" `
2023-07-14 23:19:43 +00:00
FunctionsConfig Functions ` yaml:"function" `
2023-07-22 15:31:39 +00:00
2023-08-19 14:15:22 +00:00
FeatureFlag FeatureFlag ` yaml:"feature_flags" ` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
2023-08-09 06:38:51 +00:00
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig ` yaml:",inline" `
// AutoGPTQ specifics
AutoGPTQ AutoGPTQ ` yaml:"autogptq" `
2023-08-02 22:51:08 +00:00
2023-08-09 06:38:51 +00:00
// Diffusers
Diffusers Diffusers ` yaml:"diffusers" `
2023-12-13 18:20:22 +00:00
Step int ` yaml:"step" `
2023-08-15 23:11:32 +00:00
// GRPC Options
GRPC GRPC ` yaml:"grpc" `
2023-09-04 17:25:23 +00:00
// Vall-e-x
VallE VallE ` yaml:"vall-e" `
2023-12-08 14:45:04 +00:00
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool ` yaml:"cuda" `
2024-01-01 13:39:13 +00:00
DownloadFiles [ ] File ` yaml:"download_files" `
}
// File is one entry of Config.DownloadFiles. Each field is a plain string
// that carries the same key name in both YAML and JSON.
type File struct {
	Filename string ` yaml:"filename" json:"filename" `
	// SHA256 is presumably the expected SHA-256 digest of the file — confirm
	// against the code that consumes DownloadFiles.
	SHA256 string ` yaml:"sha256" json:"sha256" `
	URI    string ` yaml:"uri" json:"uri" `
}
// VallE holds the options read from the "vall-e" section of Config.
type VallE struct {
	AudioPath string ` yaml:"audio_path" `
}
2023-08-19 14:15:22 +00:00
// FeatureFlag maps a flag name to an optional boolean. A missing key and a
// nil pointer are both treated as "not enabled".
type FeatureFlag map[string]*bool

// Enabled reports whether flag s is present, non-nil and set to true.
func (ff FeatureFlag) Enabled(s string) bool {
	// Indexing a missing key (or a nil map) yields a nil pointer, so a single
	// nil check covers both the "absent" and the "explicitly null" cases.
	if v := ff[s]; v != nil {
		return *v
	}
	return false
}
2023-08-15 23:11:32 +00:00
// GRPC holds the options read from the "grpc" section of Config.
type GRPC struct {
	Attempts int ` yaml:"attempts" `
	// AttemptsSleepTime is an integer delay; its unit is not stated in this
	// file — confirm at the point of use.
	AttemptsSleepTime int ` yaml:"attempts_sleep_time" `
}
// Diffusers holds the options read from the "diffusers" section of Config.
type Diffusers struct {
	CUDA          bool   ` yaml:"cuda" `
	PipelineType  string ` yaml:"pipeline_type" `
	SchedulerType string ` yaml:"scheduler_type" `

	// A list of comma separated parameters to specify
	EnableParameters string ` yaml:"enable_parameters" `
	// Classifier-Free Guidance Scale
	CFGScale float32 ` yaml:"cfg_scale" `
	// Image to Image Diffuser
	IMG2IMG bool ` yaml:"img2img" `
	// Skip every N frames
	// NOTE(review): for a diffusers backend this is presumably the number of
	// CLIP layers to skip rather than frames — confirm. A non-zero ClipSkip on
	// a request overrides this value (see UpdateConfigFromOpenAIRequest).
	ClipSkip int ` yaml:"clip_skip" `
	// Clip model to use
	ClipModel string ` yaml:"clip_model" `
	// Subfolder to use for clip model
	ClipSubFolder string ` yaml:"clip_subfolder" `

	ControlNet string ` yaml:"control_net" `
}
// LLMConfig groups the LLM-backend settings. Config embeds it with the yaml
// ",inline" option, so these keys are read from the same YAML level as
// Config's own fields rather than from a nested section.
type LLMConfig struct {
	SystemPrompt    string  ` yaml:"system_prompt" `
	TensorSplit     string  ` yaml:"tensor_split" `
	MainGPU         string  ` yaml:"main_gpu" `
	RMSNormEps      float32 ` yaml:"rms_norm_eps" `
	NGQA            int32   ` yaml:"ngqa" `
	PromptCachePath string  ` yaml:"prompt_cache_path" `
	PromptCacheAll  bool    ` yaml:"prompt_cache_all" `
	PromptCacheRO   bool    ` yaml:"prompt_cache_ro" `
	MirostatETA     float64 ` yaml:"mirostat_eta" `
	MirostatTAU     float64 ` yaml:"mirostat_tau" `
	Mirostat        int     ` yaml:"mirostat" `
	// NGPULayers is read from the "gpu_layers" key (note the differing name).
	NGPULayers int    ` yaml:"gpu_layers" `
	MMap       bool   ` yaml:"mmap" `
	MMlock     bool   ` yaml:"mmlock" `
	LowVRAM    bool   ` yaml:"low_vram" `
	Grammar    string ` yaml:"grammar" `
	// StopWords is extended per request by UpdateConfigFromOpenAIRequest.
	StopWords    []string ` yaml:"stopwords" `
	Cutstrings   []string ` yaml:"cutstrings" `
	TrimSpace    []string ` yaml:"trimspace" `
	TrimSuffix   []string ` yaml:"trimsuffix" `
	ContextSize  int      ` yaml:"context_size" `
	NUMA         bool     ` yaml:"numa" `
	LoraAdapter  string   ` yaml:"lora_adapter" `
	LoraBase     string   ` yaml:"lora_base" `
	LoraScale    float32  ` yaml:"lora_scale" `
	NoMulMatQ    bool     ` yaml:"no_mulmatq" `
	DraftModel   string   ` yaml:"draft_model" `
	NDraft       int32    ` yaml:"n_draft" `
	Quantization string   ` yaml:"quantization" `
	MMProj       string   ` yaml:"mmproj" `

	RopeScaling    string  ` yaml:"rope_scaling" `
	YarnExtFactor  float32 ` yaml:"yarn_ext_factor" `
	YarnAttnFactor float32 ` yaml:"yarn_attn_factor" `
	YarnBetaFast   float32 ` yaml:"yarn_beta_fast" `
	YarnBetaSlow   float32 ` yaml:"yarn_beta_slow" `
}
2023-08-07 20:39:10 +00:00
2023-08-09 06:38:51 +00:00
// AutoGPTQ holds the options read from the "autogptq" section of Config.
type AutoGPTQ struct {
	// ModelBaseName can be overridden per request (see
	// UpdateConfigFromOpenAIRequest).
	ModelBaseName    string ` yaml:"model_base_name" `
	Device           string ` yaml:"device" `
	Triton           bool   ` yaml:"triton" `
	UseFastTokenizer bool   ` yaml:"use_fast_tokenizer" `
}
// Functions holds the function-calling options that Config reads from its
// "function" YAML section.
type Functions struct {
	DisableNoAction bool ` yaml:"disable_no_action" `

	// Name and description overrides for the "no action" function.
	// NOTE(review): how these are consumed is not visible in this file.
	NoActionFunctionName    string ` yaml:"no_action_function_name" `
	NoActionDescriptionName string ` yaml:"no_action_description_name" `
}
// TemplateConfig holds one string per prompt kind, read from the "template"
// section of Config. Whether each value is a template name, a path or inline
// template text is not visible in this file.
type TemplateConfig struct {
	Chat        string ` yaml:"chat" `
	ChatMessage string ` yaml:"chat_message" `
	Completion  string ` yaml:"completion" `
	Edit        string ` yaml:"edit" `
	// Functions is read from the singular "function" key.
	Functions string ` yaml:"function" `
}
// SetFunctionCallString stores s as the function-call mode of c. The value
// "none" is what ShouldUseFunctions later treats as "functions disabled".
func (c *Config) SetFunctionCallString(s string) {
	c.functionCallString = s
}
// SetFunctionCallNameString stores s as the name of the specific function to
// call. An empty s means no specific function is selected.
func (c *Config) SetFunctionCallNameString(s string) {
	c.functionCallNameString = s
}
// ShouldUseFunctions reports whether function calling applies to c. It is
// false only when the function-call mode is exactly "none" and no specific
// function has been selected; an unset (empty) mode counts as enabled.
func (c *Config) ShouldUseFunctions() bool {
	if c.functionCallString != "none" {
		// Covers the empty string as well as any other mode.
		return true
	}
	return c.ShouldCallSpecificFunction()
}
// ShouldCallSpecificFunction reports whether a specific function name has
// been set on c (that is, the stored name is non-empty).
func (c *Config) ShouldCallSpecificFunction() bool {
	return c.functionCallNameString != ""
}
// FunctionToCall returns the function name stored by
// SetFunctionCallNameString, or "" when none was set.
func (c *Config) FunctionToCall() string {
	return c.functionCallNameString
}
// defaultPredictOptions returns the baseline PredictionOptions for modelFile:
// TopP 0.7, TopK 80, Maxtokens 512 and Temperature 0.9, with Model set to
// modelFile. All other fields keep their zero value.
func defaultPredictOptions(modelFile string) PredictionOptions {
	var opts PredictionOptions
	opts.Model = modelFile
	opts.TopP = 0.7
	opts.TopK = 80
	opts.Maxtokens = 512
	opts.Temperature = 0.9
	return opts
}
// DefaultConfig returns a new Config whose PredictionOptions are the defaults
// for modelFile; every other field is left at its zero value.
func DefaultConfig(modelFile string) *Config {
	cfg := new(Config)
	cfg.PredictionOptions = defaultPredictOptions(modelFile)
	return cfg
}
// ReadConfigFile reads file and decodes its content as a YAML list of Config
// entries. It returns a wrapped error when the file cannot be read or the
// content cannot be unmarshalled.
func ReadConfigFile(file string) ([]*Config, error) {
	data, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}

	// Start from an empty (non-nil) slice so a file that decodes no entries
	// yields an empty list.
	configs := []*Config{}
	if err = yaml.Unmarshal(data, &configs); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}
	return configs, nil
}
2024-01-05 14:34:56 +00:00
func ReadSingleConfigFile ( file string ) ( * Config , error ) {
2023-07-14 23:19:43 +00:00
c := & Config { }
f , err := os . ReadFile ( file )
if err != nil {
return nil , fmt . Errorf ( "cannot read config file: %w" , err )
}
if err := yaml . Unmarshal ( f , c ) ; err != nil {
return nil , fmt . Errorf ( "cannot unmarshal config file: %w" , err )
}
return c , nil
}
2024-01-05 14:34:56 +00:00
func UpdateConfigFromOpenAIRequest ( config * Config , input * OpenAIRequest ) {
if input . Echo {
config . Echo = input . Echo
}
if input . TopK != 0 {
config . TopK = input . TopK
}
if input . TopP != 0 {
config . TopP = input . TopP
2023-07-14 23:19:43 +00:00
}
2024-01-05 14:34:56 +00:00
if input . Backend != "" {
config . Backend = input . Backend
2023-07-14 23:19:43 +00:00
}
2024-01-05 14:34:56 +00:00
if input . ClipSkip != 0 {
config . Diffusers . ClipSkip = input . ClipSkip
2023-07-14 23:19:43 +00:00
}
2024-01-05 14:34:56 +00:00
if input . ModelBaseName != "" {
config . AutoGPTQ . ModelBaseName = input . ModelBaseName
}
2023-07-14 23:19:43 +00:00
2024-01-05 14:34:56 +00:00
if input . NegativePromptScale != 0 {
config . NegativePromptScale = input . NegativePromptScale
}
2023-07-14 23:19:43 +00:00
2024-01-05 14:34:56 +00:00
if input . UseFastTokenizer {
config . UseFastTokenizer = input . UseFastTokenizer
2023-07-31 17:14:32 +00:00
}
2024-01-05 14:34:56 +00:00
if input . NegativePrompt != "" {
config . NegativePrompt = input . NegativePrompt
2023-07-14 23:19:43 +00:00
}
2024-01-05 14:34:56 +00:00
if input . RopeFreqBase != 0 {
config . RopeFreqBase = input . RopeFreqBase
}
2023-12-18 17:58:44 +00:00
2024-01-05 14:34:56 +00:00
if input . RopeFreqScale != 0 {
config . RopeFreqScale = input . RopeFreqScale
2024-01-01 13:39:13 +00:00
}
2024-01-05 14:34:56 +00:00
if input . Grammar != "" {
config . Grammar = input . Grammar
}
2023-12-30 14:36:46 +00:00
2024-01-05 14:34:56 +00:00
if input . Temperature != 0 {
config . Temperature = input . Temperature
}
2023-12-30 14:36:46 +00:00
2024-01-05 14:34:56 +00:00
if input . Maxtokens != 0 {
config . Maxtokens = input . Maxtokens
}
2024-01-01 13:39:13 +00:00
2024-01-05 14:34:56 +00:00
if input . RepeatPenalty != 0 {
config . RepeatPenalty = input . RepeatPenalty
}
2024-01-01 13:39:13 +00:00
2024-01-05 14:34:56 +00:00
if input . Keep != 0 {
config . Keep = input . Keep
}
if input . Batch != 0 {
config . Batch = input . Batch
}
if input . F16 {
config . F16 = input . F16
}
if input . IgnoreEOS {
config . IgnoreEOS = input . IgnoreEOS
}
if input . Seed != 0 {
config . Seed = input . Seed
}
if input . Mirostat != 0 {
config . LLMConfig . Mirostat = input . Mirostat
}
if input . MirostatETA != 0 {
config . LLMConfig . MirostatETA = input . MirostatETA
}
if input . MirostatTAU != 0 {
config . LLMConfig . MirostatTAU = input . MirostatTAU
}
if input . TypicalP != 0 {
config . TypicalP = input . TypicalP
}
switch stop := input . Stop . ( type ) {
case string :
if stop != "" {
config . StopWords = append ( config . StopWords , stop )
}
case [ ] interface { } :
for _ , pp := range stop {
if s , ok := pp . ( string ) ; ok {
config . StopWords = append ( config . StopWords , s )
2024-01-01 13:39:13 +00:00
}
}
2024-01-05 14:34:56 +00:00
}
2024-01-01 13:39:13 +00:00
2024-01-05 14:34:56 +00:00
// Decode each request's message content
index := 0
for i , m := range input . Messages {
switch content := m . Content . ( type ) {
case string :
input . Messages [ i ] . StringContent = content
case [ ] interface { } :
dat , _ := json . Marshal ( content )
c := [ ] Content { }
json . Unmarshal ( dat , & c )
for _ , pp := range c {
if pp . Type == "text" {
input . Messages [ i ] . StringContent = pp . Text
} else if pp . Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64 , err := utils . GetBase64Image ( pp . ImageURL . URL )
if err == nil {
input . Messages [ i ] . StringImages = append ( input . Messages [ i ] . StringImages , base64 ) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input . Messages [ i ] . StringContent = fmt . Sprintf ( "[img-%d]" , index ) + input . Messages [ i ] . StringContent
index ++
} else {
fmt . Print ( "Failed encoding image" , err )
}
2023-12-18 17:58:44 +00:00
}
}
}
}
2024-01-05 14:34:56 +00:00
// TODO: check that this was merged correctly? I _think_ it is?
switch inputs := input . Input . ( type ) {
case string :
if inputs != "" {
config . InputStrings = append ( config . InputStrings , inputs )
}
case [ ] interface { } :
for _ , pp := range inputs {
switch i := pp . ( type ) {
case string :
config . InputStrings = append ( config . InputStrings , i )
case [ ] interface { } :
tokens := [ ] int { }
for _ , ii := range i {
tokens = append ( tokens , int ( ii . ( float64 ) ) )
}
config . InputToken = append ( config . InputToken , tokens )
}
2023-07-14 23:19:43 +00:00
}
}
2024-01-05 14:34:56 +00:00
// Can be either a string or an object
switch fnc := input . FunctionCall . ( type ) {
case string :
if fnc != "" {
config . SetFunctionCallString ( fnc )
}
case map [ string ] interface { } :
var name string
n , exists := fnc [ "name" ]
if exists {
nn , e := n . ( string )
if e {
name = nn
}
2023-07-14 23:19:43 +00:00
}
2024-01-05 14:34:56 +00:00
config . SetFunctionCallNameString ( name )
}
switch p := input . Prompt . ( type ) {
case string :
config . PromptStrings = append ( config . PromptStrings , p )
case [ ] interface { } :
for _ , pp := range p {
if s , ok := pp . ( string ) ; ok {
config . PromptStrings = append ( config . PromptStrings , s )
}
2023-07-14 23:19:43 +00:00
}
}
}