From e54ceaa497ff8f6bc82ec6304e02bb37c60730a0 Mon Sep 17 00:00:00 2001
From: Brad Davidson
Date: Fri, 31 Mar 2023 20:51:27 +0000
Subject: [PATCH] Fix issue with stale connections to removed LB server

Track LB connections through each server so that they can be closed when the
server is removed.

Signed-off-by: Brad Davidson
---
 pkg/agent/loadbalancer/loadbalancer.go      | 63 +++++++++++++++----
 pkg/agent/loadbalancer/loadbalancer_test.go | 12 ++--
 pkg/agent/loadbalancer/servers.go           | 70 ++++++++++++++++++++-
 3 files changed, 124 insertions(+), 21 deletions(-)

diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go
index efd64e445b..f47f4c38a3 100644
--- a/pkg/agent/loadbalancer/loadbalancer.go
+++ b/pkg/agent/loadbalancer/loadbalancer.go
@@ -14,10 +14,25 @@ import (
 	"inet.af/tcpproxy"
 )
 
+// server tracks the connections to a server, so that they can be closed when the server is removed.
+type server struct {
+	mutex       sync.Mutex
+	connections map[net.Conn]struct{}
+}
+
+// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
+type serverConn struct {
+	server *server
+	net.Conn
+}
+
+// LoadBalancer holds data for a local listener which forwards connections to a
+// pool of remote servers. It is not a proper load-balancer in that it does not
+// actually balance connections, but instead fails over to a new server only
+// when a connection attempt to the currently selected server fails.
 type LoadBalancer struct {
-	mutex  sync.Mutex
-	dialer *net.Dialer
-	proxy  *tcpproxy.Proxy
+	mutex sync.Mutex
+	proxy *tcpproxy.Proxy
 
 	serviceName string
 	configFile  string
@@ -27,6 +42,7 @@ type LoadBalancer struct {
 	ServerURL            string
 	ServerAddresses      []string
 	randomServers        []string
+	servers              map[string]*server
 	currentServerAddress string
 	nextServerIndex      int
 	Listener             net.Listener
@@ -40,6 +56,8 @@ var (
 	ETCDServerServiceName = version.Program + "-etcd-server-load-balancer"
 )
 
+// New constructs a new LoadBalancer instance. The default server URL, and
+// currently active servers, are stored in a file within the dataDir.
 func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	config := net.ListenConfig{Control: reusePort}
 	var localAddress string
@@ -76,11 +94,11 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
 
 	lb := &LoadBalancer{
 		serviceName:          serviceName,
-		dialer:               &net.Dialer{},
 		configFile:           filepath.Join(dataDir, "etc", serviceName+".json"),
 		localAddress:         localAddress,
 		localServerURL:       localServerURL,
 		defaultServerAddress: defaultServerAddress,
+		servers:              make(map[string]*server),
 		ServerURL:            serverURL,
 	}
 
@@ -103,14 +121,28 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
 	if err := lb.proxy.Start(); err != nil {
 		return nil, err
 	}
-	logrus.Infof("Running load balancer %s %s -> %v", serviceName, lb.localAddress, lb.randomServers)
+	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)
 
 	return lb, nil
 }
 
 func (lb *LoadBalancer) SetDefault(serverAddress string) {
-	logrus.Infof("Updating load balancer %s default server address -> %s", lb.serviceName, serverAddress)
+	lb.mutex.Lock()
+	defer lb.mutex.Unlock()
+
+	_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
+	// if the old default server is not currently in use, remove it from the server map
+	if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
+		defer server.closeAll()
+		delete(lb.servers, lb.defaultServerAddress)
+	}
+	// if the new default server doesn't have an entry in the map, add one
+	if _, ok := lb.servers[serverAddress]; !ok {
+		lb.servers[serverAddress] = &server{connections: make(map[net.Conn]struct{})}
+	}
+
 	lb.defaultServerAddress = serverAddress
+	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
 }
 
 func (lb *LoadBalancer) Update(serverAddresses []string) {
@@ -120,7 +152,7 @@ func (lb *LoadBalancer) Update(serverAddresses []string) {
 	if !lb.setServers(serverAddresses) {
 		return
 	}
-	logrus.Infof("Updating load balancer %s server addresses -> %v", lb.serviceName, lb.randomServers)
+	logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.ServerAddresses, lb.defaultServerAddress)
 
 	if err := lb.writeConfig(); err != nil {
 		logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
@@ -139,18 +171,23 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string
 	for {
 		targetServer := lb.currentServerAddress
 
-		conn, err := lb.dialer.DialContext(ctx, network, targetServer)
-		if err == nil {
-			return conn, nil
+		server := lb.servers[targetServer]
+		if server == nil || targetServer == "" {
+			logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
+		} else {
+			conn, err := server.dialContext(ctx, network, targetServer)
+			if err == nil {
+				return conn, nil
+			}
+			logrus.Debugf("Dial error from load balancer %s: %s", lb.serviceName, err)
 		}
-		logrus.Debugf("Dial error from load balancer %s: %s", lb.serviceName, err)
 
 		newServer, err := lb.nextServer(targetServer)
 		if err != nil {
 			return nil, err
 		}
 		if targetServer != newServer {
-			logrus.Debugf("Dial server in load balancer %s failed over to %s", lb.serviceName, newServer)
+			logrus.Debugf("Failed over to new server for load balancer %s: %s", lb.serviceName, newServer)
 		}
 		if ctx.Err() != nil {
 			return nil, ctx.Err()
@@ -167,7 +204,7 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string
 }
 
 func onDialError(src net.Conn, dstDialErr error) {
-	logrus.Debugf("Incoming conn %v, error dialing load balancer servers: %v", src.RemoteAddr().String(), dstDialErr)
+	logrus.Debugf("Incoming conn %s, error dialing load balancer servers: %v", src.RemoteAddr(), dstDialErr)
 	src.Close()
 }
diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go
index d36c503e5e..e91f52760e 100644
--- a/pkg/agent/loadbalancer/loadbalancer_test.go
+++ b/pkg/agent/loadbalancer/loadbalancer_test.go
@@ -15,18 +15,18 @@ import (
 	"github.com/k3s-io/k3s/pkg/cli/cmds"
 )
 
-type server struct {
+type testServer struct {
 	listener net.Listener
 	conns    []net.Conn
 	prefix   string
 }
 
-func createServer(prefix string) (*server, error) {
+func createServer(prefix string) (*testServer, error) {
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 		return nil, err
 	}
-	s := &server{
+	s := &testServer{
 		prefix:   prefix,
 		listener: listener,
 	}
@@ -34,7 +34,7 @@ func createServer(prefix string) (*server, error) {
 	return s, nil
 }
 
-func (s *server) serve() {
+func (s *testServer) serve() {
 	for {
 		conn, err := s.listener.Accept()
 		if err != nil {
@@ -45,14 +45,14 @@ func (s *server) serve() {
 	}
 }
 
-func (s *server) close() {
+func (s *testServer) close() {
 	s.listener.Close()
 	for _, conn := range s.conns {
 		conn.Close()
 	}
 }
 
-func (s *server) echo(conn net.Conn) {
+func (s *testServer) echo(conn net.Conn) {
 	for {
 		result, err := bufio.NewReader(conn).ReadString('\n')
 		if err != nil {
diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go
index e0de81034b..e8e9f83a15 100644
--- a/pkg/agent/loadbalancer/servers.go
+++ b/pkg/agent/loadbalancer/servers.go
@@ -1,11 +1,17 @@
 package loadbalancer
 
 import (
+	"context"
 	"errors"
 	"math/rand"
-	"reflect"
+	"net"
+
+	"github.com/sirupsen/logrus"
+	"k8s.io/apimachinery/pkg/util/sets"
 )
 
+var defaultDialer = &net.Dialer{}
+
 func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
 	serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.defaultServerAddress)
 	if len(serverAddresses) == 0 {
@@ -15,10 +21,32 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
 	lb.mutex.Lock()
 	defer lb.mutex.Unlock()
 
-	if reflect.DeepEqual(serverAddresses, lb.ServerAddresses) {
+	newAddresses := sets.NewString(serverAddresses...)
+	curAddresses := sets.NewString(lb.ServerAddresses...)
+	if newAddresses.Equal(curAddresses) {
 		return false
 	}
 
+	for addedServer := range newAddresses.Difference(curAddresses) {
+		logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
+		lb.servers[addedServer] = &server{connections: make(map[net.Conn]struct{})}
+	}
+
+	for removedServer := range curAddresses.Difference(newAddresses) {
+		server := lb.servers[removedServer]
+		if server != nil {
+			logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
+			// Defer closing connections until after the new server list has been put into place.
+			// Closing open connections ensures that anything stuck retrying on a stale server is forced
+			// over to a valid endpoint.
+			defer server.closeAll()
+			// Don't delete the default server from the server map, in case we need to fall back to it.
+			if removedServer != lb.defaultServerAddress {
+				delete(lb.servers, removedServer)
+			}
+		}
+	}
+
 	lb.ServerAddresses = serverAddresses
 	lb.randomServers = append([]string{}, lb.ServerAddresses...)
 	rand.Shuffle(len(lb.randomServers), func(i, j int) {
@@ -55,3 +83,41 @@ func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
 
 	return lb.currentServerAddress, nil
 }
+
+// dialContext dials a new connection, and adds its wrapped connection to the map
+func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	conn, err := defaultDialer.DialContext(ctx, network, address)
+	if err != nil {
+		return nil, err
+	}
+	// Don't take the lock until after the dial completes, so that a slow dial
+	// doesn't hold the lock and block other connections while waiting to time out.
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	conn = &serverConn{server: s, Conn: conn}
+	s.connections[conn] = struct{}{}
+	return conn, nil
+}
+
+// closeAll closes all connections to the server, and removes their entries from the map
+func (s *server) closeAll() {
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+
+	logrus.Debugf("Closing %d connections to load balancer server", len(s.connections))
+	for conn := range s.connections {
+		// Close in a goroutine, as serverConn.Close takes the lock we currently hold.
+		go conn.Close()
+	}
+}
+
+// Close removes the connection entry from the server's connection map, and
+// closes the wrapped connection.
+func (sc *serverConn) Close() error {
+	sc.server.mutex.Lock()
+	defer sc.server.mutex.Unlock()
+
+	delete(sc.server.connections, sc)
+	return sc.Conn.Close()
+}
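
The heart of the change is the connection-tracking pattern: every dialed net.Conn is wrapped in a serverConn whose Close removes it from the owning server's map, and removing a server calls closeAll to force-close whatever is still in the map. The following is a minimal standalone sketch of that pattern, not part of the patch: the track helper is a stand-in for the wrap-and-record step at the end of server.dialContext, and net.Pipe stands in for a real dialed connection.

	package main

	import (
		"fmt"
		"net"
		"sync"
	)

	// server tracks open connections so that they can all be closed
	// when the server is removed from the load balancer.
	type server struct {
		mutex       sync.Mutex
		connections map[net.Conn]struct{}
	}

	// serverConn wraps a net.Conn so that closing it also removes it
	// from the owning server's connection map.
	type serverConn struct {
		server *server
		net.Conn
	}

	// Close removes the connection from the server's map, then closes
	// the underlying connection.
	func (sc *serverConn) Close() error {
		sc.server.mutex.Lock()
		defer sc.server.mutex.Unlock()
		delete(sc.server.connections, sc)
		return sc.Conn.Close()
	}

	// track wraps a dialed connection and records it in the map; in the
	// patch this happens inline at the end of server.dialContext.
	func (s *server) track(conn net.Conn) net.Conn {
		s.mutex.Lock()
		defer s.mutex.Unlock()
		wrapped := &serverConn{server: s, Conn: conn}
		s.connections[wrapped] = struct{}{}
		return wrapped
	}

	// closeAll force-closes every tracked connection. Close is invoked in
	// a goroutine because serverConn.Close takes the mutex that closeAll
	// already holds; the closes run once closeAll releases it.
	func (s *server) closeAll() {
		s.mutex.Lock()
		defer s.mutex.Unlock()
		for conn := range s.connections {
			go conn.Close()
		}
	}

	func main() {
		s := &server{connections: make(map[net.Conn]struct{})}

		// net.Pipe stands in for a real dial.
		client, backend := net.Pipe()
		defer backend.Close()

		conn := s.track(client)
		fmt.Println("tracked connections:", len(s.connections)) // 1

		// Closing the wrapped connection removes its map entry.
		conn.Close()
		fmt.Println("after Close:", len(s.connections)) // 0

		// closeAll would force-close anything still tracked, as when
		// the server is removed from the load balancer.
		s.closeAll()
	}

Note the ordering choice in setServers: closeAll is deferred rather than called inline, so connections are not closed until after the new server list is in place, and anything stuck retrying a stale server reconnects straight to a valid endpoint.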
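The other piece is the switch from reflect.DeepEqual to set comparison in setServers, which yields the added and removed servers directly. A small sketch of that logic, with made-up addresses; it assumes the same k8s.io/apimachinery/pkg/util/sets package the patch imports:

	package main

	import (
		"fmt"

		"k8s.io/apimachinery/pkg/util/sets"
	)

	func main() {
		// Current and desired server lists, as compared in setServers.
		curAddresses := sets.NewString("1.1.1.1:6443", "2.2.2.2:6443")
		newAddresses := sets.NewString("2.2.2.2:6443", "3.3.3.3:6443")

		if newAddresses.Equal(curAddresses) {
			fmt.Println("no change")
			return
		}

		// Servers to create tracking state for.
		for addedServer := range newAddresses.Difference(curAddresses) {
			fmt.Println("add:", addedServer) // 3.3.3.3:6443
		}
		// Servers whose connections get force-closed.
		for removedServer := range curAddresses.Difference(newAddresses) {
			fmt.Println("remove:", removedServer) // 1.1.1.1:6443
		}
	}

Compared to the old reflect.DeepEqual check, the set comparison is order-insensitive, and the two Difference loops give a natural place to create and tear down the per-server connection state.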