diff --git a/pkg/etcd/etcd.go b/pkg/etcd/etcd.go index 587801d8a7..8ec60d7cc9 100644 --- a/pkg/etcd/etcd.go +++ b/pkg/etcd/etcd.go @@ -50,6 +50,7 @@ import ( const ( defaultEndpoint = "https://127.0.0.1:2379" + defaultEndpointv6 = "https://[::1]:2379" testTimeout = time.Second * 10 manageTickerTime = time.Second * 15 learnerMaxStallTime = time.Minute * 5 @@ -150,7 +151,7 @@ func (e *ETCD) EndpointName() string { func (e *ETCD) SetControlConfig(ctx context.Context, config *config.Control) error { e.config = config - client, err := GetClient(ctx, e.config.Runtime) + client, err := GetClient(ctx, e.config) if err != nil { return err } @@ -178,7 +179,7 @@ func (e *ETCD) Test(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() - endpoints := getEndpoints(e.config.Runtime) + endpoints := getEndpoints(e.config) status, err := e.client.Status(ctx, endpoints[0]) if err != nil { return err @@ -424,7 +425,7 @@ func (e *ETCD) join(ctx context.Context, clientAccessInfo *clientaccess.Info) er return err } - client, err := GetClient(clientCtx, e.config.Runtime, clientURLs...) + client, err := GetClient(clientCtx, e.config, clientURLs...) if err != nil { return err } @@ -501,7 +502,7 @@ func (e *ETCD) join(ctx context.Context, clientAccessInfo *clientaccess.Info) er func (e *ETCD) Register(ctx context.Context, config *config.Control, handler http.Handler) (http.Handler, error) { e.config = config - client, err := GetClient(ctx, e.config.Runtime) + client, err := GetClient(ctx, e.config) if err != nil { return nil, err } @@ -518,7 +519,7 @@ func (e *ETCD) Register(ctx context.Context, config *config.Control, handler htt } e.address = address - endpoints := getEndpoints(config.Runtime) + endpoints := getEndpoints(config) e.config.Datastore.Endpoint = endpoints[0] e.config.Datastore.BackendTLSConfig.CAFile = e.config.Runtime.ETCDServerCA e.config.Datastore.BackendTLSConfig.CertFile = e.config.Runtime.ClientETCDCert @@ -609,8 +610,8 @@ func (e *ETCD) infoHandler() http.Handler { // If the runtime config does not list any endpoints, the default endpoint is used. // The returned client should be closed when no longer needed, in order to avoid leaking GRPC // client goroutines. -func GetClient(ctx context.Context, runtime *config.ControlRuntime, endpoints ...string) (*clientv3.Client, error) { - cfg, err := getClientConfig(ctx, runtime, endpoints...) +func GetClient(ctx context.Context, control *config.Control, endpoints ...string) (*clientv3.Client, error) { + cfg, err := getClientConfig(ctx, control, endpoints...) if err != nil { return nil, err } @@ -620,9 +621,10 @@ func GetClient(ctx context.Context, runtime *config.ControlRuntime, endpoints .. // getClientConfig generates an etcd client config connected to the specified endpoints. // If no endpoints are provided, getEndpoints is called to provide defaults. -func getClientConfig(ctx context.Context, runtime *config.ControlRuntime, endpoints ...string) (*clientv3.Config, error) { +func getClientConfig(ctx context.Context, control *config.Control, endpoints ...string) (*clientv3.Config, error) { + runtime := control.Runtime if len(endpoints) == 0 { - endpoints = getEndpoints(runtime) + endpoints = getEndpoints(control) } config := &clientv3.Config{ @@ -641,10 +643,14 @@ func getClientConfig(ctx context.Context, runtime *config.ControlRuntime, endpoi } // getEndpoints returns the endpoints from the runtime config if set, otherwise the default endpoint. -func getEndpoints(runtime *config.ControlRuntime) []string { +func getEndpoints(control *config.Control) []string { + runtime := control.Runtime if len(runtime.EtcdConfig.Endpoints) > 0 { return runtime.EtcdConfig.Endpoints } + if utilsnet.IsIPv6String(control.PrivateIP) { + return []string{defaultEndpointv6} + } return []string{defaultEndpoint} } @@ -730,7 +736,7 @@ func (e *ETCD) migrateFromSQLite(ctx context.Context) error { } defer sqliteClient.Close() - etcdClient, err := GetClient(ctx, e.config.Runtime) + etcdClient, err := GetClient(ctx, e.config) if err != nil { return err } @@ -827,7 +833,7 @@ func (e *ETCD) StartEmbeddedTemporary(ctx context.Context) error { return err } - endpoints := getEndpoints(e.config.Runtime) + endpoints := getEndpoints(e.config) clientURL := endpoints[0] peerURL, err := addPort(endpoints[0], 1) if err != nil { @@ -915,7 +921,7 @@ func (e *ETCD) manageLearners(ctx context.Context) { logrus.Debug("Etcd client was nil") continue } - endpoints := getEndpoints(e.config.Runtime) + endpoints := getEndpoints(e.config) if status, err := e.client.Status(ctx, endpoints[0]); err != nil { logrus.Errorf("Failed to check local etcd status for learner management: %v", err) continue @@ -1071,7 +1077,7 @@ func (e *ETCD) defragment(ctx context.Context) error { } logrus.Infof("Defragmenting etcd database") - endpoints := getEndpoints(e.config.Runtime) + endpoints := getEndpoints(e.config) _, err := e.client.Defragment(ctx, endpoints[0]) return err } @@ -1144,7 +1150,7 @@ func (e *ETCD) preSnapshotSetup(ctx context.Context, config *config.Control) err if e.config == nil { e.config = config } - client, err := GetClient(ctx, e.config.Runtime) + client, err := GetClient(ctx, e.config) if err != nil { return err } @@ -1272,7 +1278,7 @@ func (e *ETCD) Snapshot(ctx context.Context, config *config.Control) error { } } - endpoints := getEndpoints(e.config.Runtime) + endpoints := getEndpoints(e.config) status, err := e.client.Status(ctx, endpoints[0]) if err != nil { return errors.Wrap(err, "failed to check etcd status for snapshot") @@ -1288,7 +1294,7 @@ func (e *ETCD) Snapshot(ctx context.Context, config *config.Control) error { return errors.Wrap(err, "failed to get the snapshot dir") } - cfg, err := getClientConfig(ctx, e.config.Runtime) + cfg, err := getClientConfig(ctx, e.config) if err != nil { return errors.Wrap(err, "failed to get config for etcd snapshot") } @@ -2013,7 +2019,7 @@ func backupDirWithRetention(dir string, maxBackupRetention int) (string, error) // GetAPIServerURLsFromETCD will try to fetch the version.Program/apiaddresses key from etcd // and unmarshal it to a list of apiserver endpoints. func GetAPIServerURLsFromETCD(ctx context.Context, cfg *config.Control) ([]string, error) { - cl, err := GetClient(ctx, cfg.Runtime) + cl, err := GetClient(ctx, cfg) if err != nil { return nil, err } diff --git a/pkg/etcd/etcd_test.go b/pkg/etcd/etcd_test.go index cff23e2f9e..0e7094e012 100644 --- a/pkg/etcd/etcd_test.go +++ b/pkg/etcd/etcd_test.go @@ -244,7 +244,7 @@ func Test_UnitETCD_Start(t *testing.T) { ctxInfo.ctx, ctxInfo.cancel = context.WithCancel(context.Background()) e.config.EtcdDisableSnapshots = true testutil.GenerateRuntime(e.config) - client, err := GetClient(ctxInfo.ctx, e.config.Runtime) + client, err := GetClient(ctxInfo.ctx, e.config) e.client = client return err @@ -275,7 +275,7 @@ func Test_UnitETCD_Start(t *testing.T) { setup: func(e *ETCD, ctxInfo *contextInfo) error { ctxInfo.ctx, ctxInfo.cancel = context.WithCancel(context.Background()) testutil.GenerateRuntime(e.config) - client, err := GetClient(ctxInfo.ctx, e.config.Runtime) + client, err := GetClient(ctxInfo.ctx, e.config) e.client = client return err @@ -308,7 +308,7 @@ func Test_UnitETCD_Start(t *testing.T) { if err := testutil.GenerateRuntime(e.config); err != nil { return err } - client, err := GetClient(ctxInfo.ctx, e.config.Runtime) + client, err := GetClient(ctxInfo.ctx, e.config) if err != nil { return err }