mirror of
https://github.com/k3s-io/k3s.git
synced 2024-06-07 19:41:36 +00:00
e54ceaa497
Track LB connections through each server so that they can be closed when it is removed. Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
219 lines
6.3 KiB
Go
219 lines
6.3 KiB
Go
package loadbalancer
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
|
|
"github.com/k3s-io/k3s/pkg/version"
|
|
"github.com/sirupsen/logrus"
|
|
"inet.af/tcpproxy"
|
|
)
|
|
|
|
// server tracks the connections to a server, so that they can be closed when the server is removed.
|
|
type server struct {
|
|
mutex sync.Mutex
|
|
connections map[net.Conn]struct{}
|
|
}
|
|
|
|
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
|
|
type serverConn struct {
|
|
server *server
|
|
net.Conn
|
|
}
|
|
|
|
// LoadBalancer holds data for a local listener which forwards connections to a
|
|
// pool of remote servers. It is not a proper load-balancer in that it does not
|
|
// actually balance connections, but instead fails over to a new server only
|
|
// when a connection attempt to the currently selected server fails.
|
|
type LoadBalancer struct {
|
|
mutex sync.Mutex
|
|
proxy *tcpproxy.Proxy
|
|
|
|
serviceName string
|
|
configFile string
|
|
localAddress string
|
|
localServerURL string
|
|
defaultServerAddress string
|
|
ServerURL string
|
|
ServerAddresses []string
|
|
randomServers []string
|
|
servers map[string]*server
|
|
currentServerAddress string
|
|
nextServerIndex int
|
|
Listener net.Listener
|
|
}
|
|
|
|
const RandomPort = 0
|
|
|
|
var (
|
|
SupervisorServiceName = version.Program + "-agent-load-balancer"
|
|
APIServerServiceName = version.Program + "-api-server-agent-load-balancer"
|
|
ETCDServerServiceName = version.Program + "-etcd-server-load-balancer"
|
|
)
|
|
|
|
// New contstructs a new LoadBalancer instance. The default server URL, and
|
|
// currently active servers, are stored in a file within the dataDir.
|
|
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
|
|
config := net.ListenConfig{Control: reusePort}
|
|
var localAddress string
|
|
if isIPv6 {
|
|
localAddress = fmt.Sprintf("[::1]:%d", lbServerPort)
|
|
} else {
|
|
localAddress = fmt.Sprintf("127.0.0.1:%d", lbServerPort)
|
|
}
|
|
listener, err := config.Listen(ctx, "tcp", localAddress)
|
|
defer func() {
|
|
if _err != nil {
|
|
logrus.Warnf("Error starting load balancer: %s", _err)
|
|
if listener != nil {
|
|
listener.Close()
|
|
}
|
|
}
|
|
}()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
|
|
localAddress = listener.Addr().String()
|
|
|
|
defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if serverURL == localServerURL {
|
|
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
|
|
defaultServerAddress = ""
|
|
}
|
|
|
|
lb := &LoadBalancer{
|
|
serviceName: serviceName,
|
|
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
|
|
localAddress: localAddress,
|
|
localServerURL: localServerURL,
|
|
defaultServerAddress: defaultServerAddress,
|
|
servers: make(map[string]*server),
|
|
ServerURL: serverURL,
|
|
}
|
|
|
|
lb.setServers([]string{lb.defaultServerAddress})
|
|
|
|
lb.proxy = &tcpproxy.Proxy{
|
|
ListenFunc: func(string, string) (net.Listener, error) {
|
|
return listener, nil
|
|
},
|
|
}
|
|
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
|
|
Addr: serviceName,
|
|
DialContext: lb.dialContext,
|
|
OnDialError: onDialError,
|
|
})
|
|
|
|
if err := lb.updateConfig(); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := lb.proxy.Start(); err != nil {
|
|
return nil, err
|
|
}
|
|
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)
|
|
|
|
return lb, nil
|
|
}
|
|
|
|
func (lb *LoadBalancer) SetDefault(serverAddress string) {
|
|
lb.mutex.Lock()
|
|
defer lb.mutex.Unlock()
|
|
|
|
_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
|
|
// if the old default server is not currently in use, remove it from the server map
|
|
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
|
|
defer server.closeAll()
|
|
delete(lb.servers, lb.defaultServerAddress)
|
|
}
|
|
// if the new default server doesn't have an entry in the map, add one
|
|
if _, ok := lb.servers[serverAddress]; !ok {
|
|
lb.servers[serverAddress] = &server{connections: make(map[net.Conn]struct{})}
|
|
}
|
|
|
|
lb.defaultServerAddress = serverAddress
|
|
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
|
|
}
|
|
|
|
func (lb *LoadBalancer) Update(serverAddresses []string) {
|
|
if lb == nil {
|
|
return
|
|
}
|
|
if !lb.setServers(serverAddresses) {
|
|
return
|
|
}
|
|
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.ServerAddresses, lb.defaultServerAddress)
|
|
|
|
if err := lb.writeConfig(); err != nil {
|
|
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
|
|
}
|
|
}
|
|
|
|
func (lb *LoadBalancer) LoadBalancerServerURL() string {
|
|
if lb == nil {
|
|
return ""
|
|
}
|
|
return lb.localServerURL
|
|
}
|
|
|
|
func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
startIndex := lb.nextServerIndex
|
|
for {
|
|
targetServer := lb.currentServerAddress
|
|
|
|
server := lb.servers[targetServer]
|
|
if server == nil || targetServer == "" {
|
|
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
|
|
} else {
|
|
conn, err := server.dialContext(ctx, network, targetServer)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
logrus.Debugf("Dial error from load balancer %s: %s", lb.serviceName, err)
|
|
}
|
|
|
|
newServer, err := lb.nextServer(targetServer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if targetServer != newServer {
|
|
logrus.Debugf("Failed over to new server for load balancer %s: %s", lb.serviceName, newServer)
|
|
}
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
maxIndex := len(lb.randomServers)
|
|
if startIndex > maxIndex {
|
|
startIndex = maxIndex
|
|
}
|
|
if lb.nextServerIndex == startIndex {
|
|
return nil, errors.New("all servers failed")
|
|
}
|
|
}
|
|
}
|
|
|
|
func onDialError(src net.Conn, dstDialErr error) {
|
|
logrus.Debugf("Incoming conn %s, error dialing load balancer servers: %v", src.RemoteAddr(), dstDialErr)
|
|
src.Close()
|
|
}
|
|
|
|
// ResetLoadBalancer will delete the local state file for the load balancer on disk
|
|
func ResetLoadBalancer(dataDir, serviceName string) error {
|
|
stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
|
|
if err := os.Remove(stateFile); err != nil {
|
|
logrus.Warn(err)
|
|
}
|
|
return nil
|
|
}
|