diff --git a/.env b/.env index c4691deb..a31e1e37 100644 --- a/.env +++ b/.env @@ -72,4 +72,18 @@ MODELS_PATH=/models # LLAMACPP_PARALLEL=1 ### Enable to run parallel requests -# PARALLEL_REQUESTS=true \ No newline at end of file +# PARALLEL_REQUESTS=true + +### Watchdog settings +### +# Enables watchdog to kill backends that are inactive for too much time +# WATCHDOG_IDLE=true +# +# Enables watchdog to kill backends that are busy for too much time +# WATCHDOG_BUSY=true +# +# Time in duration format (e.g. 1h30m) after which a backend is considered idle +# WATCHDOG_IDLE_TIMEOUT=5m +# +# Time in duration format (e.g. 1h30m) after which a backend is considered busy +# WATCHDOG_BUSY_TIMEOUT=5m \ No newline at end of file diff --git a/api/api.go b/api/api.go index 1da844f9..9a097838 100644 --- a/api/api.go +++ b/api/api.go @@ -13,6 +13,7 @@ import ( "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/assets" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" @@ -79,6 +80,22 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, options.Loader.StopAllGRPC() }() + if options.WatchDog { + wd := model.NewWatchDog( + options.Loader, + options.WatchDogBusyTimeout, + options.WatchDogIdleTimeout, + options.WatchDogBusy, + options.WatchDogIdle) + options.Loader.SetWatchDog(wd) + go wd.Run() + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + wd.Shutdown() + }() + } + return options, cl, nil } diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go index 3aea8a84..8cb0bb45 100644 --- a/api/localai/backend_monitor.go +++ b/api/localai/backend_monitor.go @@ -128,7 +128,7 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { return fmt.Errorf("backend %s is not currently loaded", backendId) } - status, rpcErr := 
model.GRPC(false).Status(context.TODO()) + status, rpcErr := model.GRPC(false, nil).Status(context.TODO()) if rpcErr != nil { log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) val, slbErr := bm.SampleLocalBackendProcess(backendId) diff --git a/api/options/options.go b/api/options/options.go index 9488e549..127d06f0 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -4,6 +4,7 @@ import ( "context" "embed" "encoding/json" + "time" "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" @@ -38,6 +39,11 @@ type Option struct { SingleBackend bool ParallelBackendRequests bool + + WatchDogIdle bool + WatchDogBusy bool + WatchDog bool + WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration } type AppOption func(*Option) @@ -63,6 +69,32 @@ func WithCors(b bool) AppOption { } } +var EnableWatchDog = func(o *Option) { + o.WatchDog = true +} + +var EnableWatchDogIdleCheck = func(o *Option) { + o.WatchDog = true + o.WatchDogIdle = true +} + +var EnableWatchDogBusyCheck = func(o *Option) { + o.WatchDog = true + o.WatchDogBusy = true +} + +func SetWatchDogBusyTimeout(t time.Duration) AppOption { + return func(o *Option) { + o.WatchDogBusyTimeout = t + } +} + +func SetWatchDogIdleTimeout(t time.Duration) AppOption { + return func(o *Option) { + o.WatchDogIdleTimeout = t + } +} + var EnableSingleBackend = func(o *Option) { o.SingleBackend = true } diff --git a/main.go b/main.go index bc9d1bba..97b258c0 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "syscall" + "time" api "github.com/go-skynet/LocalAI/api" "github.com/go-skynet/LocalAI/api/backend" @@ -154,6 +155,30 @@ func main() { Usage: "List of API Keys to enable API authentication. 
When this is set, all the requests must be authenticated with one of these API keys.", EnvVars: []string{"API_KEY"}, }, + &cli.BoolFlag{ + Name: "enable-watchdog-idle", + Usage: "Enable watchdog for stopping idle backends. This will stop the backends if they are in idle state for too long.", + EnvVars: []string{"WATCHDOG_IDLE"}, + Value: false, + }, + &cli.BoolFlag{ + Name: "enable-watchdog-busy", + Usage: "Enable watchdog for stopping busy backends that exceed a defined threshold.", + EnvVars: []string{"WATCHDOG_BUSY"}, + Value: false, + }, + &cli.StringFlag{ + Name: "watchdog-busy-timeout", + Usage: "Watchdog busy timeout. Backends that stay busy longer than this duration will be stopped.", + EnvVars: []string{"WATCHDOG_BUSY_TIMEOUT"}, + Value: "5m", + }, + &cli.StringFlag{ + Name: "watchdog-idle-timeout", + Usage: "Watchdog idle timeout. Backends that stay idle longer than this duration will be stopped.", + EnvVars: []string{"WATCHDOG_IDLE_TIMEOUT"}, + Value: "15m", + }, &cli.BoolFlag{ + Name: "preload-backend-only", + Usage: "If set, the api is NOT launched, and only the preloaded models / backends are started. 
This is intended for multi-node setups.", @@ -198,6 +223,28 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithUploadLimitMB(ctx.Int("upload-limit")), options.WithApiKeys(ctx.StringSlice("api-keys")), } + + idleWatchDog := ctx.Bool("enable-watchdog-idle") + busyWatchDog := ctx.Bool("enable-watchdog-busy") + if idleWatchDog || busyWatchDog { + opts = append(opts, options.EnableWatchDog) + if idleWatchDog { + opts = append(opts, options.EnableWatchDogIdleCheck) + dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout")) + if err != nil { + return err + } + opts = append(opts, options.SetWatchDogIdleTimeout(dur)) + } + if busyWatchDog { + opts = append(opts, options.EnableWatchDogBusyCheck) + dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout")) + if err != nil { + return err + } + opts = append(opts, options.SetWatchDogBusyTimeout(dur)) + } + } if ctx.Bool("parallel-requests") { opts = append(opts, options.EnableParallelBackendRequests) } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 673e2a54..9eab356d 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -19,12 +19,22 @@ type Client struct { parallel bool sync.Mutex opMutex sync.Mutex + wd WatchDog } -func NewClient(address string, parallel bool) *Client { +type WatchDog interface { + Mark(address string) + UnMark(address string) +} + +func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client { + if !enableWatchDog { + wd = nil + } return &Client{ address: address, parallel: parallel, + wd: wd, } } @@ -79,6 +89,10 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ... 
} c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -96,6 +110,10 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -113,6 +131,10 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -129,6 +151,10 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return err @@ -164,6 +190,10 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -180,6 +210,10 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -196,6 +230,10 @@ func (c *Client) 
AudioTranscription(ctx context.Context, in *pb.TranscriptReques } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -232,6 +270,10 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts } c.setBusy(true) defer c.setBusy(false) + if c.wd != nil { + c.wd.Mark(c.address) + defer c.wd.UnMark(c.address) + } conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index a17163cd..e6b5934c 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -121,7 +121,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string // Wait for the service to start up ready := false for i := 0; i < o.grpcAttempts; i++ { - if client.GRPC(o.parallelRequests).HealthCheck(context.Background()) { + if client.GRPC(o.parallelRequests, ml.wd).HealthCheck(context.Background()) { log.Debug().Msgf("GRPC Service Ready") ready = true break @@ -140,7 +140,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string log.Debug().Msgf("GRPC: Loading model with options: %+v", options) - res, err := client.GRPC(o.parallelRequests).LoadModel(o.context, &options) + res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options) if err != nil { return "", fmt.Errorf("could not load model: %w", err) } @@ -154,11 +154,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) { if parallel { - return addr.GRPC(parallel), nil + return addr.GRPC(parallel, ml.wd), nil } if _, ok := ml.grpcClients[string(addr)]; !ok { - ml.grpcClients[string(addr)] = 
addr.GRPC(parallel) + ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd) } return ml.grpcClients[string(addr)], nil } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 60671301..493ee083 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -63,12 +63,17 @@ type ModelLoader struct { models map[string]ModelAddress grpcProcesses map[string]*process.Process templates map[TemplateType]map[string]*template.Template + wd *WatchDog } type ModelAddress string -func (m ModelAddress) GRPC(parallel bool) *grpc.Client { - return grpc.NewClient(string(m), parallel) +func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client { + enableWD := false + if wd != nil { + enableWD = true + } + return grpc.NewClient(string(m), parallel, wd, enableWD) } func NewModelLoader(modelPath string) *ModelLoader { @@ -79,10 +84,15 @@ func NewModelLoader(modelPath string) *ModelLoader { templates: make(map[TemplateType]map[string]*template.Template), grpcProcesses: make(map[string]*process.Process), } + nml.initializeTemplateMap() return nml } +func (ml *ModelLoader) SetWatchDog(wd *WatchDog) { + ml.wd = wd +} + func (ml *ModelLoader) ExistsInModelPath(s string) bool { return existsInPath(ml.ModelPath, s) } @@ -139,11 +149,17 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) ( func (ml *ModelLoader) ShutdownModel(modelName string) error { ml.mu.Lock() defer ml.mu.Unlock() + + return ml.StopModel(modelName) +} + +func (ml *ModelLoader) StopModel(modelName string) error { + defer ml.deleteProcess(modelName) if _, ok := ml.models[modelName]; !ok { return fmt.Errorf("model %s not found", modelName) } - - return ml.deleteProcess(modelName) + return nil + //return ml.deleteProcess(modelName) } func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { @@ -153,7 +169,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress { if c, ok := ml.grpcClients[s]; ok { client = c } else { - client = m.GRPC(false) + client = 
m.GRPC(false, ml.wd) } if !client.HealthCheck(context.Background()) { diff --git a/pkg/model/process.go b/pkg/model/process.go index 18f44a66..5f63ee7f 100644 --- a/pkg/model/process.go +++ b/pkg/model/process.go @@ -17,7 +17,7 @@ import ( func (ml *ModelLoader) StopAllExcept(s string) { ml.StopGRPC(func(id string, p *process.Process) bool { if id != s { - for ml.models[id].GRPC(false).IsBusy() { + for ml.models[id].GRPC(false, ml.wd).IsBusy() { log.Debug().Msgf("%s busy. Waiting.", id) time.Sleep(2 * time.Second) } @@ -80,6 +80,11 @@ func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string process.WithEnvironment(os.Environ()...), ) + if ml.wd != nil { + ml.wd.Add(serverAddress, grpcControlProcess) + ml.wd.AddAddressModelMap(serverAddress, id) + } + ml.grpcProcesses[id] = grpcControlProcess if err := grpcControlProcess.Run(); err != nil { diff --git a/pkg/model/watchdog.go b/pkg/model/watchdog.go new file mode 100644 index 00000000..cf313180 --- /dev/null +++ b/pkg/model/watchdog.go @@ -0,0 +1,155 @@ +package model + +import ( + "sync" + "time" + + process "github.com/mudler/go-processmanager" + "github.com/rs/zerolog/log" +) + +// All GRPC Clients created by ModelLoader should have an associated injected +// watchdog that will keep track of the state of each backend (busy or not) +// and for how much time it has been busy. 
+// If a backend is busy for too long, the watchdog will kill the process and +// force a reload of the model +// The watchdog runs as a separate go routine, +// and the GRPC client talks to it via a channel to send status updates + +type WatchDog struct { + sync.Mutex + timetable map[string]time.Time + idleTime map[string]time.Time + timeout, idletimeout time.Duration + addressMap map[string]*process.Process + addressModelMap map[string]string + pm ProcessManager + stop chan bool + + busyCheck, idleCheck bool +} + +type ProcessManager interface { + StopModel(modelName string) error +} + +func NewWatchDog(pm ProcessManager, timeoutBusy, timeoutIdle time.Duration, busy, idle bool) *WatchDog { + return &WatchDog{ + timeout: timeoutBusy, + idletimeout: timeoutIdle, + pm: pm, + timetable: make(map[string]time.Time), + idleTime: make(map[string]time.Time), + addressMap: make(map[string]*process.Process), + busyCheck: busy, + idleCheck: idle, + addressModelMap: make(map[string]string), + } +} + +func (wd *WatchDog) Shutdown() { + wd.Lock() + defer wd.Unlock() + wd.stop <- true +} + +func (wd *WatchDog) AddAddressModelMap(address string, model string) { + wd.Lock() + defer wd.Unlock() + wd.addressModelMap[address] = model + +} +func (wd *WatchDog) Add(address string, p *process.Process) { + wd.Lock() + defer wd.Unlock() + wd.addressMap[address] = p +} + +func (wd *WatchDog) Mark(address string) { + wd.Lock() + defer wd.Unlock() + wd.timetable[address] = time.Now() + delete(wd.idleTime, address) +} + +func (wd *WatchDog) UnMark(ModelAddress string) { + wd.Lock() + defer wd.Unlock() + delete(wd.timetable, ModelAddress) + wd.idleTime[ModelAddress] = time.Now() +} + +func (wd *WatchDog) Run() { + log.Info().Msg("[WatchDog] starting watchdog") + + for { + select { + case <-wd.stop: + log.Info().Msg("[WatchDog] Stopping watchdog") + return + case <-time.After(30 * time.Second): + if !wd.busyCheck && !wd.idleCheck { + log.Info().Msg("[WatchDog] No checks enabled, stopping 
watchdog") + return + } + if wd.busyCheck { + wd.checkBusy() + } + if wd.idleCheck { + wd.checkIdle() + } + } + } +} + +func (wd *WatchDog) checkIdle() { + wd.Lock() + defer wd.Unlock() + log.Debug().Msg("[WatchDog] Watchdog checks for idle connections") + for address, t := range wd.idleTime { + log.Debug().Msgf("[WatchDog] %s: idle connection", address) + if time.Since(t) > wd.idletimeout { + log.Warn().Msgf("[WatchDog] Address %s is idle for too long, killing it", address) + p, ok := wd.addressModelMap[address] + if ok { + if err := wd.pm.StopModel(p); err != nil { + log.Error().Msgf("[watchdog] Error shutting down model %s: %v", p, err) + } + delete(wd.idleTime, address) + delete(wd.addressModelMap, address) + delete(wd.addressMap, address) + } else { + log.Warn().Msgf("[WatchDog] Address %s unresolvable", address) + delete(wd.idleTime, address) + } + } + } +} + +func (wd *WatchDog) checkBusy() { + wd.Lock() + defer wd.Unlock() + log.Debug().Msg("[WatchDog] Watchdog checks for busy connections") + + for address, t := range wd.timetable { + log.Debug().Msgf("[WatchDog] %s: active connection", address) + + if time.Since(t) > wd.timeout { + + model, ok := wd.addressModelMap[address] + if ok { + log.Warn().Msgf("[WatchDog] Model %s is busy for too long, killing it", model) + if err := wd.pm.StopModel(model); err != nil { + log.Error().Msgf("[watchdog] Error shutting down model %s: %v", model, err) + } + delete(wd.timetable, address) + delete(wd.addressModelMap, address) + delete(wd.addressMap, address) + } else { + log.Warn().Msgf("[WatchDog] Address %s unresolvable", address) + delete(wd.timetable, address) + } + + } + } +}