diff --git a/pkg/clientaccess/token.go b/pkg/clientaccess/token.go index e0f1586adf..c914881f87 100644 --- a/pkg/clientaccess/token.go +++ b/pkg/clientaccess/token.go @@ -13,10 +13,11 @@ import ( "time" "github.com/pkg/errors" + "github.com/sirupsen/logrus" ) var ( - defaultClientTimeout = 20 * time.Second + defaultClientTimeout = 10 * time.Second defaultClient = &http.Client{ Timeout: defaultClientTimeout, @@ -32,8 +33,9 @@ var ( ) const ( - tokenPrefix = "K10" - tokenFormat = "%s%s::%s:%s" + tokenPrefix = "K10" + tokenFormat = "%s%s::%s:%s" + caHashLength = sha256.Size * 2 ) type OverrideURLCallback func(config []byte) (*url.URL, error) @@ -59,16 +61,10 @@ func ParseAndValidateToken(server string, token string) (*Info, error) { return nil, err } - if err := info.setServer(server); err != nil { + if err := info.setAndValidateServer(server); err != nil { return nil, err } - if info.caHash != "" { - if err := info.validateCAHash(); err != nil { - return nil, err - } - } - return info, nil } @@ -82,26 +78,24 @@ func ParseAndValidateTokenForUser(server string, token string, username string) info.Username = username - if err := info.setServer(server); err != nil { + if err := info.setAndValidateServer(server); err != nil { return nil, err } - if info.caHash != "" { - if err := info.validateCAHash(); err != nil { - return nil, err - } - } - return info, nil } +// setAndValidateServer updates the remote server's cert info, and validates it against the provided hash +func (info *Info) setAndValidateServer(server string) error { + if err := info.setServer(server); err != nil { + return err + } + return info.validateCAHash() +} + // validateCACerts returns a boolean indicating whether or not a CA bundle matches the provided hash, // and a string containing the hash of the CA bundle. func validateCACerts(cacerts []byte, hash string) (bool, string) { - if len(cacerts) == 0 && hash == "" { - return true, "" - } - newHash := hashCA(cacerts) return hash == newHash, newHash } @@ -126,6 +120,10 @@ func ParseUsernamePassword(token string) (string, string, bool) { func parseToken(token string) (*Info, error) { var info = &Info{} + if len(token) == 0 { + return nil, errors.New("token must not be empty") + } + if !strings.HasPrefix(token, tokenPrefix) { token = fmt.Sprintf(tokenFormat, tokenPrefix, "", "", token) } @@ -136,13 +134,17 @@ func parseToken(token string) (*Info, error) { parts := strings.SplitN(token, "::", 2) token = parts[0] if len(parts) > 1 { + hashLen := len(parts[0]) + if hashLen > 0 && hashLen != caHashLength { + return nil, errors.New("invalid token CA hash length") + } info.caHash = parts[0] token = parts[1] } parts = strings.SplitN(token, ":", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid token format") + if len(parts) != 2 || len(parts[1]) == 0 { + return nil, errors.New("invalid token format") } info.Username = parts[0] @@ -212,10 +214,20 @@ func (info *Info) setServer(server string) error { // ValidateCAHash validates that info's caHash matches the CACerts hash. func (info *Info) validateCAHash() error { - if ok, serverHash := validateCACerts(info.CACerts, info.caHash); !ok { - return fmt.Errorf("token CA hash does not match the server CA hash: %s != %s", info.caHash, serverHash) + if len(info.caHash) > 0 && len(info.CACerts) == 0 { + // Warn if the user provided a CA hash but we're not going to validate because it's already trusted + logrus.Warn("Cluster CA certificate is trusted by the host CA bundle. " + + "Token CA hash will not be validated.") + } else if len(info.caHash) == 0 && len(info.CACerts) > 0 { + // Warn if the CA is self-signed but the user didn't provide a hash to validate it against + logrus.Warn("Cluster CA certificate is not trusted by the host CA bundle, but the token does not include a CA hash. " + + "Use the full token from the server's node-token file to enable Cluster CA validation.") + } else if len(info.CACerts) > 0 && len(info.caHash) > 0 { + // only verify CA hash if the server cert is not trusted by the OS CA bundle + if ok, serverHash := validateCACerts(info.CACerts, info.caHash); !ok { + return fmt.Errorf("token CA hash does not match the Cluster CA certificate hash: %s != %s", info.caHash, serverHash) + } } - return nil } diff --git a/pkg/clientaccess/token_test.go b/pkg/clientaccess/token_test.go new file mode 100644 index 0000000000..df63e62bc2 --- /dev/null +++ b/pkg/clientaccess/token_test.go @@ -0,0 +1,401 @@ +package clientaccess + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "net" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/rancher/dynamiclistener/cert" + "github.com/rancher/dynamiclistener/factory" + "github.com/rancher/k3s/pkg/bootstrap" + "github.com/rancher/k3s/pkg/daemons/config" + "github.com/stretchr/testify/assert" +) + +var ( + defaultUsername = "server" + defaultPassword = "token" +) + +// TestTrustedCA confirms that tokens are validated when the server uses a cert (self-signed or otherwise) +// that is trusted by the OS CA bundle. This test must be run first, since it mucks with the system root certs. +func TestTrustedCA(t *testing.T) { + assert := assert.New(t) + server := newTLSServer(t, defaultUsername, defaultPassword, false) + defer server.Close() + + testInfo := &Info{ + CACerts: getServerCA(server), + BaseURL: server.URL, + Username: defaultUsername, + Password: defaultPassword, + caHash: hashCA(getServerCA(server)), + } + + testCases := []struct { + token string + expected string + }{ + {defaultPassword, ""}, + {testInfo.String(), testInfo.Username}, + } + + // Point OS CA bundle at this test's CA cert to simulate a trusted CA cert. + // Note that this only works if the OS CA bundle has not yet been loaded in this process, + // as it is cached for the duration of the process lifetime. + // Ref: https://github.com/golang/go/issues/41888 + path := t.TempDir() + "/ca.crt" + writeServerCA(server, path) + os.Setenv("SSL_CERT_FILE", path) + + for _, testCase := range testCases { + info, err := ParseAndValidateToken(server.URL, testCase.token) + if assert.NoError(err, testCase) { + assert.Nil(info.CACerts, testCase) + assert.Equal(testCase.expected, info.Username, testCase.token) + } + + info, err = ParseAndValidateTokenForUser(server.URL, testCase.token, "agent") + if assert.NoError(err, testCase) { + assert.Nil(info.CACerts, testCase) + assert.Equal("agent", info.Username, testCase) + } + } + + // Confirm that the cert is actually trusted by the OS CA bundle by making a request + // with empty cert pool + testInfo.CACerts = nil + res, err := Get("/v1-k3s/server-bootstrap", testInfo) + assert.NoError(err) + assert.NotEmpty(res) +} + +// TestUntrustedCA confirms that tokens are validated when the server uses a self-signed cert +// that is NOT trusted by the OS CA bundle. +func TestUntrustedCA(t *testing.T) { + assert := assert.New(t) + server := newTLSServer(t, defaultUsername, defaultPassword, false) + defer server.Close() + + testInfo := &Info{ + CACerts: getServerCA(server), + BaseURL: server.URL, + Username: defaultUsername, + Password: defaultPassword, + caHash: hashCA(getServerCA(server)), + } + + testCases := []struct { + token string + expected string + }{ + {defaultPassword, ""}, + {testInfo.String(), testInfo.Username}, + } + + for _, testCase := range testCases { + info, err := ParseAndValidateToken(server.URL, testCase.token) + if assert.NoError(err, testCase) { + assert.Equal(testInfo.CACerts, info.CACerts, testCase) + assert.Equal(testCase.expected, info.Username, testCase) + } + + info, err = ParseAndValidateTokenForUser(server.URL, testCase.token, "agent") + if assert.NoError(err, testCase) { + assert.Equal(testInfo.CACerts, info.CACerts, testCase) + assert.Equal("agent", info.Username, testCase) + } + } +} + +// TestInvalidServers tests that invalid server URLs are properly rejected +func TestInvalidServers(t *testing.T) { + assert := assert.New(t) + testCases := []struct { + server string + token string + expected string + }{ + {" https://localhost:6443", "token", "Invalid server url, failed to parse: https://localhost:6443: parse \" https://localhost:6443\": first path segment in URL cannot contain colon"}, + {"http://localhost:6443", "token", "only https:// URLs are supported, invalid scheme: http://localhost:6443"}, + } + + for _, testCase := range testCases { + _, err := ParseAndValidateToken(testCase.server, testCase.token) + assert.EqualError(err, testCase.expected, testCase) + + _, err = ParseAndValidateTokenForUser(testCase.server, testCase.token, defaultUsername) + assert.EqualError(err, testCase.expected, testCase) + } +} + +// TestInvalidTokens tests that tokens which are empty, invalid, or incorrect are properly rejected +func TestInvalidTokens(t *testing.T) { + assert := assert.New(t) + server := newTLSServer(t, defaultUsername, defaultPassword, false) + defer server.Close() + + testCases := []struct { + server string + token string + expected string + }{ + {server.URL, "", "token must not be empty"}, + {server.URL, "K10::", "invalid token format"}, + {server.URL, "K10::x", "invalid token format"}, + {server.URL, "K10::x:", "invalid token format"}, + {server.URL, "K10XX::x:y", "invalid token CA hash length"}, + {server.URL, + "K10XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX::x:y", + "token CA hash does not match the Cluster CA certificate hash: XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX != " + hashCA(getServerCA(server))}, + } + + for _, testCase := range testCases { + info, err := ParseAndValidateToken(testCase.server, testCase.token) + assert.EqualError(err, testCase.expected, testCase) + assert.Nil(info, testCase) + + info, err = ParseAndValidateTokenForUser(testCase.server, testCase.token, defaultUsername) + assert.EqualError(err, testCase.expected, testCase) + assert.Nil(info, testCase) + } +} + +// TestInvalidCredentials tests that tokens which don't have valid credentials are rejected +func TestInvalidCredentials(t *testing.T) { + assert := assert.New(t) + server := newTLSServer(t, defaultUsername, defaultPassword, false) + defer server.Close() + + testInfo := &Info{ + CACerts: getServerCA(server), + BaseURL: server.URL, + Username: "nobody", + Password: "invalid", + caHash: hashCA(getServerCA(server)), + } + + testCases := []string{ + testInfo.Password, + testInfo.String(), + } + + for _, testCase := range testCases { + info, err := ParseAndValidateToken(server.URL, testCase) + assert.NoError(err, testCase) + if assert.NotNil(info) { + res, err := Get("/v1-k3s/server-bootstrap", info) + assert.Error(err, testCase) + assert.Empty(res, testCase) + } + + info, err = ParseAndValidateTokenForUser(server.URL, testCase, defaultUsername) + assert.NoError(err, testCase) + if assert.NotNil(info) { + res, err := Get("/v1-k3s/server-bootstrap", info) + assert.Error(err, testCase) + assert.Empty(res, testCase) + } + } +} + +// TestWrongCert tests that errors are returned when the server's cert isn't issued by its CA +func TestWrongCert(t *testing.T) { + assert := assert.New(t) + server := newTLSServer(t, defaultUsername, defaultPassword, true) + defer server.Close() + + info, err := ParseAndValidateToken(server.URL, defaultPassword) + assert.Error(err) + assert.Nil(info) + + info, err = ParseAndValidateTokenForUser(server.URL, defaultPassword, defaultUsername) + assert.Error(err) + assert.Nil(info) +} + +// TestConnectionFailures tests that connections are timed out properly +func TestConnectionFailures(t *testing.T) { + testDuration := (defaultClientTimeout * 2) + time.Second + assert := assert.New(t) + testCases := []struct { + server string + token string + }{ + {"https://192.0.2.1:6443", "token"}, // RFC 5735 TEST-NET-1 for use in documentation and example code + {"https://localhost:1", "token"}, + } + + for _, testCase := range testCases { + startTime := time.Now() + info, err := ParseAndValidateToken(testCase.server, testCase.token) + assert.Error(err, testCase) + assert.Nil(info, testCase) + assert.WithinDuration(time.Now(), startTime, testDuration, testCase) + + startTime = time.Now() + info, err = ParseAndValidateTokenForUser(testCase.server, testCase.token, defaultUsername) + assert.Error(err, testCase) + assert.Nil(info, testCase) + assert.WithinDuration(startTime, time.Now(), testDuration, testCase) + } +} + +// TestUserPass tests that usernames and passwords are parsed or not parsed from token strings +func TestUserPass(t *testing.T) { + assert := assert.New(t) + testCases := []struct { + token string + username string + password string + expect bool + }{ + {"K10XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX::username:password", "username", "password", true}, + {"password", "", "password", true}, + {"K10X::x", "", "", false}, + } + + for _, testCase := range testCases { + username, password, ok := ParseUsernamePassword(testCase.token) + assert.Equal(testCase.expect, ok, testCase) + if ok { + assert.Equal(testCase.username, username, testCase) + assert.Equal(testCase.password, password, testCase) + } + } +} + +// TestParseAndGet tests URL handling along some hard-to-reach code paths +func TestParseAndGet(t *testing.T) { + assert := assert.New(t) + server := newTLSServer(t, defaultUsername, defaultPassword, false) + defer server.Close() + + testCases := []struct { + extraBasePre string + extraBasePost string + path string + parseFail bool + getFail bool + }{ + {"/", "", "/cacerts", false, false}, + {"/%2", "", "/cacerts", true, false}, + {"", "", "/%2", false, true}, + {"", "/%2", "/cacerts", false, true}, + } + + for _, testCase := range testCases { + info, err := ParseAndValidateTokenForUser(server.URL+testCase.extraBasePre, defaultPassword, defaultUsername) + // Check for expected error when parsing server + token + if testCase.parseFail { + assert.Error(err, testCase) + } else if assert.NoError(err, testCase) { + info.BaseURL = server.URL + testCase.extraBasePost + _, err := Get(testCase.path, info) + // Check for expected error when making Get request + if testCase.getFail { + assert.Error(err, testCase) + } else { + assert.NoError(err, testCase) + } + } + } +} + +// newTLSServer returns a HTTPS server that mocks the basic functionality required to validate K3s join tokens. +// Each call to this function will generate new CA and server certificates unique to the returned server. +func newTLSServer(t *testing.T, username, password string, sendWrongCA bool) *httptest.Server { + var server *httptest.Server + server = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v1-k3s/server-bootstrap" { + if authUsername, authPassword, ok := r.BasicAuth(); ok != true || authPassword != password || authUsername != username { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + bootstrapData := &config.ControlRuntimeBootstrap{} + w.Header().Set("Content-Type", "application/json") + if err := bootstrap.Write(w, bootstrapData); err != nil { + t.Errorf("failed to write bootstrap: %v", err) + } + return + } + + if r.URL.Path == "/cacerts" { + w.Header().Set("Content-Type", "text/plain") + if _, err := w.Write(getServerCA(server)); err != nil { + t.Errorf("Failed to write cacerts: %v", err) + } + return + } + + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + })) + + // Create new CA cert and key + caCert, caKey, err := factory.GenCA() + if err != nil { + t.Fatal(err) + } + + // Generate new server cert; reuse the key from the CA + cfg := cert.Config{ + CommonName: "localhost", + Organization: []string{"testing"}, + Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + AltNames: cert.AltNames{ + DNSNames: []string{"localhost"}, + IPs: []net.IP{net.IPv4(127, 0, 0, 1)}, + }, + } + serverCert, err := cert.NewSignedCert(cfg, caKey, caCert, caKey) + if err != nil { + t.Fatal(err) + } + + // Bind server and CA certs into chain for TLS listener configuration + server.TLS = &tls.Config{} + server.TLS.Certificates = []tls.Certificate{ + {Certificate: [][]byte{serverCert.Raw}, Leaf: serverCert, PrivateKey: caKey}, + {Certificate: [][]byte{caCert.Raw}, Leaf: caCert}, + } + + if sendWrongCA { + // Create new CA cert and key and use that as the CA cert instead of the one that actually signed the server cert + badCert, _, err := factory.GenCA() + if err != nil { + t.Fatal(err) + } + server.TLS.Certificates[1].Certificate[0] = badCert.Raw + server.TLS.Certificates[1].Leaf = badCert + } + + server.StartTLS() + return server +} + +// getServerCA returns a byte slice containing the PEM encoding of the server's CA certificate +func getServerCA(server *httptest.Server) []byte { + certLen := len(server.TLS.Certificates) + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[certLen-1].Certificate[0]}) +} + +// writeServerCA writes the PEM-encoded server certificate to a given path +func writeServerCA(server *httptest.Server, path string) error { + certOut, err := os.Create(path) + if err != nil { + return err + } + defer certOut.Close() + + if _, err := certOut.Write(getServerCA(server)); err != nil { + return err + } + + return nil +}