diff --git a/pkg/agent/config/config.go b/pkg/agent/config/config.go index 4ca106f433..c9669d37f2 100644 --- a/pkg/agent/config/config.go +++ b/pkg/agent/config/config.go @@ -6,6 +6,7 @@ import ( cryptorand "crypto/rand" "crypto/tls" "encoding/hex" + "encoding/pem" "fmt" "io/ioutil" sysnet "net" @@ -141,14 +142,13 @@ func getServingCert(nodeName, servingCertFile, servingKeyFile, nodePasswordFile if err != nil { return nil, err } + + servingCert, servingKey := splitCertKeyPEM(servingCert) + if err := ioutil.WriteFile(servingCertFile, servingCert, 0600); err != nil { return nil, errors.Wrapf(err, "failed to write node cert") } - servingKey, err := clientaccess.Get("/v1-k3s/serving-kubelet.key", info) - if err != nil { - return nil, err - } if err := ioutil.WriteFile(servingKeyFile, servingKey, 0600); err != nil { return nil, errors.Wrapf(err, "failed to write node key") } @@ -160,27 +160,60 @@ func getServingCert(nodeName, servingCertFile, servingKeyFile, nodePasswordFile return &cert, nil } -func getHostFile(filename string, info *clientaccess.Info) error { +func getHostFile(filename, keyFile string, info *clientaccess.Info) error { basename := filepath.Base(filename) fileBytes, err := clientaccess.Get("/v1-k3s/"+basename, info) if err != nil { return err } - if err := ioutil.WriteFile(filename, fileBytes, 0600); err != nil { - return errors.Wrapf(err, "failed to write cert %s", filename) + if keyFile == "" { + if err := ioutil.WriteFile(filename, fileBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write cert %s", filename) + } + } else { + fileBytes, keyBytes := splitCertKeyPEM(fileBytes) + if err := ioutil.WriteFile(filename, fileBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write cert %s", filename) + } + if err := ioutil.WriteFile(keyFile, keyBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write key %s", filename) + } } return nil } -func getNodeNamedHostFile(filename, nodeName, nodePasswordFile string, info *clientaccess.Info) error { +func splitCertKeyPEM(bytes []byte) (certPem []byte, keyPem []byte) { + for { + b, rest := pem.Decode(bytes) + if b == nil { + break + } + bytes = rest + + if strings.Contains(b.Type, "PRIVATE KEY") { + keyPem = append(keyPem, pem.EncodeToMemory(b)...) + } else { + certPem = append(certPem, pem.EncodeToMemory(b)...) + } + } + + return +} + +func getNodeNamedHostFile(filename, keyFile, nodeName, nodePasswordFile string, info *clientaccess.Info) error { basename := filepath.Base(filename) fileBytes, err := Request("/v1-k3s/"+basename, info, getNodeNamedCrt(nodeName, nodePasswordFile)) if err != nil { return err } + fileBytes, keyBytes := splitCertKeyPEM(fileBytes) + if err := ioutil.WriteFile(filename, fileBytes, 0600); err != nil { return errors.Wrapf(err, "failed to write cert %s", filename) } + if err := ioutil.WriteFile(keyFile, keyBytes, 0600); err != nil { + return errors.Wrapf(err, "failed to write key %s", filename) + } return nil } @@ -287,12 +320,12 @@ func get(envInfo *cmds.Agent) (*config.Node, error) { } clientCAFile := filepath.Join(envInfo.DataDir, "client-ca.crt") - if err := getHostFile(clientCAFile, info); err != nil { + if err := getHostFile(clientCAFile, "", info); err != nil { return nil, err } serverCAFile := filepath.Join(envInfo.DataDir, "server-ca.crt") - if err := getHostFile(serverCAFile, info); err != nil { + if err := getHostFile(serverCAFile, "", info); err != nil { return nil, err } @@ -331,12 +364,8 @@ func get(envInfo *cmds.Agent) (*config.Node, error) { } clientKubeletCert := filepath.Join(envInfo.DataDir, "client-kubelet.crt") - if err := getNodeNamedHostFile(clientKubeletCert, nodeName, newNodePasswordFile, info); err != nil { - return nil, err - } - clientKubeletKey := filepath.Join(envInfo.DataDir, "client-kubelet.key") - if err := getHostFile(clientKubeletKey, info); err != nil { + if err := getNodeNamedHostFile(clientKubeletCert, clientKubeletKey, nodeName, newNodePasswordFile, info); err != nil { return nil, err } @@ -346,12 +375,8 @@ func get(envInfo *cmds.Agent) (*config.Node, error) { } clientKubeProxyCert := filepath.Join(envInfo.DataDir, "client-kube-proxy.crt") - if err := getHostFile(clientKubeProxyCert, info); err != nil { - return nil, err - } - clientKubeProxyKey := filepath.Join(envInfo.DataDir, "client-kube-proxy.key") - if err := getHostFile(clientKubeProxyKey, info); err != nil { + if err := getHostFile(clientKubeProxyCert, clientKubeProxyKey, info); err != nil { return nil, err } @@ -361,12 +386,8 @@ func get(envInfo *cmds.Agent) (*config.Node, error) { } clientK3sControllerCert := filepath.Join(envInfo.DataDir, "client-k3s-controller.crt") - if err := getHostFile(clientK3sControllerCert, info); err != nil { - return nil, err - } - clientK3sControllerKey := filepath.Join(envInfo.DataDir, "client-k3s-controller.key") - if err := getHostFile(clientK3sControllerKey, info); err != nil { + if err := getHostFile(clientK3sControllerCert, clientK3sControllerKey, info); err != nil { return nil, err } diff --git a/pkg/server/router.go b/pkg/server/router.go index b9908b6dce..a3b9f23c5b 100644 --- a/pkg/server/router.go +++ b/pkg/server/router.go @@ -29,14 +29,10 @@ func router(serverConfig *config.Control, tunnel http.Handler, ca []byte) http.H authed := mux.NewRouter() authed.Use(authMiddleware(serverConfig, "k3s:agent")) authed.NotFoundHandler = serverConfig.Runtime.Handler - authed.Path("/v1-k3s/serving-kubelet.crt").Handler(servingKubeletCert(serverConfig)) - authed.Path("/v1-k3s/serving-kubelet.key").Handler(fileHandler(serverConfig.Runtime.ServingKubeletKey)) - authed.Path("/v1-k3s/client-kubelet.crt").Handler(clientKubeletCert(serverConfig)) - authed.Path("/v1-k3s/client-kubelet.key").Handler(fileHandler(serverConfig.Runtime.ClientKubeletKey)) - authed.Path("/v1-k3s/client-kube-proxy.crt").Handler(fileHandler(serverConfig.Runtime.ClientKubeProxyCert)) - authed.Path("/v1-k3s/client-kube-proxy.key").Handler(fileHandler(serverConfig.Runtime.ClientKubeProxyKey)) - authed.Path("/v1-k3s/client-k3s-controller.crt").Handler(fileHandler(serverConfig.Runtime.ClientK3sControllerCert)) - authed.Path("/v1-k3s/client-k3s-controller.key").Handler(fileHandler(serverConfig.Runtime.ClientK3sControllerKey)) + authed.Path("/v1-k3s/serving-kubelet.crt").Handler(servingKubeletCert(serverConfig, serverConfig.Runtime.ServingKubeletKey)) + authed.Path("/v1-k3s/client-kubelet.crt").Handler(clientKubeletCert(serverConfig, serverConfig.Runtime.ClientKubeletKey)) + authed.Path("/v1-k3s/client-kube-proxy.crt").Handler(fileHandler(serverConfig.Runtime.ClientKubeProxyCert, serverConfig.Runtime.ClientKubeProxyKey)) + authed.Path("/v1-k3s/client-k3s-controller.crt").Handler(fileHandler(serverConfig.Runtime.ClientK3sControllerCert, serverConfig.Runtime.ClientK3sControllerKey)) authed.Path("/v1-k3s/client-ca.crt").Handler(fileHandler(serverConfig.Runtime.ClientCA)) authed.Path("/v1-k3s/server-ca.crt").Handler(fileHandler(serverConfig.Runtime.ServerCA)) authed.Path("/v1-k3s/config").Handler(configHandler(serverConfig)) @@ -119,7 +115,7 @@ func getCACertAndKeys(caCertFile, caKeyFile, signingKeyFile string) ([]*x509.Cer return caCert, caKey.(crypto.Signer), key.(crypto.Signer), nil } -func servingKubeletCert(server *config.Control) http.Handler { +func servingKubeletCert(server *config.Control, keyFile string) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if req.TLS == nil { resp.WriteHeader(http.StatusNotFound) @@ -156,11 +152,18 @@ func servingKubeletCert(server *config.Control) http.Handler { return } + keyBytes, err := ioutil.ReadFile(keyFile) + if err != nil { + http.Error(resp, err.Error(), http.StatusInternalServerError) + return + } + resp.Write(append(certutil.EncodeCertPEM(cert), certutil.EncodeCertPEM(caCert[0])...)) + resp.Write(keyBytes) }) } -func clientKubeletCert(server *config.Control) http.Handler { +func clientKubeletCert(server *config.Control, keyFile string) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if req.TLS == nil { resp.WriteHeader(http.StatusNotFound) @@ -194,17 +197,39 @@ func clientKubeletCert(server *config.Control) http.Handler { return } + keyBytes, err := ioutil.ReadFile(keyFile) + if err != nil { + http.Error(resp, err.Error(), http.StatusInternalServerError) + return + } + resp.Write(append(certutil.EncodeCertPEM(cert), certutil.EncodeCertPEM(caCert[0])...)) + resp.Write(keyBytes) }) } -func fileHandler(fileName string) http.Handler { +func fileHandler(fileName ...string) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { if req.TLS == nil { resp.WriteHeader(http.StatusNotFound) return } - http.ServeFile(resp, req, fileName) + resp.Header().Set("Content-Type", "text/plain") + + if len(fileName) == 1 { + http.ServeFile(resp, req, fileName[0]) + return + } + + for _, f := range fileName { + bytes, err := ioutil.ReadFile(f) + if err != nil { + logrus.Errorf("Failed to read %s: %v", f, err) + resp.WriteHeader(http.StatusInternalServerError) + return + } + resp.Write(bytes) + } }) }