diff --git a/pkg/agent/run.go b/pkg/agent/run.go index 640e19c26a..9a4521149f 100644 --- a/pkg/agent/run.go +++ b/pkg/agent/run.go @@ -47,7 +47,7 @@ func run(ctx context.Context, cfg cmds.Agent) error { return err } - if err := tunnel.Setup(nodeConfig); err != nil { + if err := tunnel.Setup(ctx, nodeConfig); err != nil { return err } diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index badc519ab9..f7884591c5 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -52,7 +52,7 @@ func getAddresses(endpoint *v1.Endpoints) []string { return serverAddresses } -func Setup(config *config.Node) error { +func Setup(ctx context.Context, config *config.Node) error { restConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigNode) if err != nil { return err @@ -83,7 +83,6 @@ func Setup(config *config.Node) error { disconnect[address] = connect(wg, address, config, transportConfig) } } - wg.Wait() go func() { connect: @@ -134,6 +133,19 @@ func Setup(config *config.Node) error { } }() + wait := make(chan int, 1) + go func() { + wg.Wait() + wait <- 0 + }() + + select { + case <-ctx.Done(): + logrus.Error("tunnel context canceled while waiting for connection") + return ctx.Err() + case <-wait: + } + return nil } @@ -178,6 +190,9 @@ func connect(waitGroup *sync.WaitGroup, address string, config *config.Node, tra }) if ctx.Err() != nil { + if waitGroup != nil { + once.Do(waitGroup.Done) + } logrus.Infof("Stopped tunnel to %s", wsURL) return }