Update vendored kvsql

This commit is contained in:
Erik Wilson 2019-06-15 12:18:08 -07:00
parent 32283aa8e2
commit 043da2e539
7 changed files with 159 additions and 61 deletions

View File

@ -142,7 +142,7 @@ import:
- package: github.com/hashicorp/golang-lru
version: v0.5.0
- package: github.com/ibuildthecloud/kvsql
version: 1afc2d8ad7d7e263c1971b05cb37e83aa5562561
version: 79f1f6881e28b90976f070aad6edad8e259057c1
repo: https://github.com/erikwilson/rancher-kvsql.git
- package: github.com/imdario/mergo
version: v0.3.5

View File

@ -122,8 +122,7 @@ golang.org/x/oauth2 a6bd8cefa1811bd24b86f8902872e4e8225f74c4
golang.org/x/time f51c12702a4d776e4c1fa9b0fabab841babae631
gopkg.in/inf.v0 3887ee99ecf07df5b447e9b00d9c0b2adaa9f3e4
gopkg.in/yaml.v2 v2.2.1
#github.com/ibuildthecloud/kvsql 788464096f5af361d166858efccf26c12dc5b427
github.com/ibuildthecloud/kvsql 1afc2d8ad7d7e263c1971b05cb37e83aa5562561 https://github.com/erikwilson/rancher-kvsql.git
github.com/ibuildthecloud/kvsql 79f1f6881e28b90976f070aad6edad8e259057c1 https://github.com/erikwilson/rancher-kvsql.git
# rootless
github.com/rootless-containers/rootlesskit v0.4.1

View File

@ -18,6 +18,7 @@ import (
"crypto/tls"
"time"
"github.com/coreos/etcd/pkg/transport"
"google.golang.org/grpc"
)
@ -39,4 +40,6 @@ type Config struct {
DialTimeout time.Duration
DialOptions []grpc.DialOption
TLSInfo *transport.TLSInfo
}

View File

@ -5,10 +5,16 @@ import (
"database/sql"
"strings"
"github.com/coreos/etcd/pkg/transport"
"github.com/go-sql-driver/mysql"
"github.com/ibuildthecloud/kvsql/clientv3/driver"
)
const (
defaultUnixDSN = "root@unix(/var/run/mysqld/mysqld.sock)/"
defaultHostDSN = "root@tcp(127.0.0.1)/"
)
var (
fieldList = "name, value, old_value, old_revision, create_revision, revision, ttl, version, del"
baseList = `
@ -46,7 +52,7 @@ INSERT INTO key_value(` + fieldList + `)
}
nameIdx = "create index name_idx on key_value (name(100))"
revisionIdx = "create index revision_idx on key_value (revision)"
createDB = "create database if not exists kubernetes"
createDB = "create database if not exists "
)
func NewMySQL() *driver.Generic {
@ -65,31 +71,24 @@ func NewMySQL() *driver.Generic {
}
}
func Open(dataSourceName string, tlsConfig *tls.Config) (*sql.DB, error) {
if dataSourceName == "" {
dataSourceName = "root@unix(/var/run/mysqld/mysqld.sock)/"
func Open(dataSourceName string, tlsInfo *transport.TLSInfo) (*sql.DB, error) {
tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
return nil, err
}
// get database name
dsList := strings.Split(dataSourceName, "/")
databaseName := dsList[len(dsList)-1]
if databaseName == "" {
if err := createDBIfNotExist(dataSourceName); err != nil {
return nil, err
}
dataSourceName = dataSourceName + "kubernetes"
tlsConfig.MinVersion = tls.VersionTLS11
if len(tlsInfo.CertFile) == 0 && len(tlsInfo.KeyFile) == 0 && len(tlsInfo.CAFile) == 0 {
tlsConfig = nil
}
parsedDSN, err := prepareDSN(dataSourceName, tlsConfig)
if err != nil {
return nil, err
}
if err := createDBIfNotExist(parsedDSN); err != nil {
return nil, err
}
// setting up tlsConfig
if tlsConfig != nil {
mysql.RegisterTLSConfig("custom", tlsConfig)
if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName + ",tls=custom"
} else {
dataSourceName = dataSourceName + "?tls=custom"
}
}
db, err := sql.Open("mysql", dataSourceName)
db, err := sql.Open("mysql", parsedDSN)
if err != nil {
return nil, err
}
@ -116,13 +115,30 @@ func Open(dataSourceName string, tlsConfig *tls.Config) (*sql.DB, error) {
}
func createDBIfNotExist(dataSourceName string) error {
config, err := mysql.ParseDSN(dataSourceName)
if err != nil {
return err
}
dbName := config.DBName
db, err := sql.Open("mysql", dataSourceName)
if err != nil {
return err
}
_, err = db.Exec(createDB)
_, err = db.Exec(createDB + dbName)
if err != nil {
return err
if mysqlError, ok := err.(*mysql.MySQLError); !ok || mysqlError.Number != 1049 {
return err
}
config.DBName = ""
db, err = sql.Open("mysql", config.FormatDSN())
if err != nil {
return err
}
_, err = db.Exec(createDB + dbName)
if err != nil {
return err
}
}
return nil
}
@ -130,11 +146,35 @@ func createDBIfNotExist(dataSourceName string) error {
func createIndex(db *sql.DB, indexStmt string) error {
_, err := db.Exec(indexStmt)
if err != nil {
// check if its a duplicate error
if err.(*mysql.MySQLError).Number == 1061 {
return nil
if mysqlError, ok := err.(*mysql.MySQLError); !ok || mysqlError.Number != 1061 {
return err
}
return err
}
return nil
}
func prepareDSN(dataSourceName string, tlsConfig *tls.Config) (string, error) {
if len(dataSourceName) == 0 {
dataSourceName = defaultUnixDSN
if tlsConfig != nil {
dataSourceName = defaultHostDSN
}
}
config, err := mysql.ParseDSN(dataSourceName)
if err != nil {
return "", err
}
// setting up tlsConfig
if tlsConfig != nil {
mysql.RegisterTLSConfig("custom", tlsConfig)
config.TLSConfig = "custom"
}
dbName := "kubernetes"
if len(config.DBName) > 0 {
dbName = config.DBName
}
config.DBName = dbName
parsedDSN := config.FormatDSN()
return parsedDSN, nil
}

View File

@ -2,14 +2,20 @@ package pgsql
import (
"database/sql"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/coreos/etcd/pkg/transport"
"github.com/ibuildthecloud/kvsql/clientv3/driver"
"github.com/lib/pq"
)
const (
defaultDSN = "postgres://postgres:postgres@localhost/"
)
var (
fieldList = "name, value, old_value, old_revision, create_revision, revision, ttl, version, del"
baseList = `
@ -46,7 +52,7 @@ INSERT INTO key_value(` + fieldList + `)
`create index if not exists name_idx on key_value (name)`,
`create index if not exists revision_idx on key_value (revision)`,
}
createDB = "create database kubernetes"
createDB = "create database "
)
func NewPGSQL() *driver.Generic {
@ -65,22 +71,16 @@ func NewPGSQL() *driver.Generic {
}
}
func Open(dataSourceName string) (*sql.DB, error) {
if dataSourceName == "" {
dataSourceName = "postgres://postgres:postgres@localhost/"
} else {
dataSourceName = "postgres://" + dataSourceName
func Open(dataSourceName string, tlsInfo *transport.TLSInfo) (*sql.DB, error) {
parsedDSN, err := prepareDSN(dataSourceName, tlsInfo)
if err != nil {
return nil, err
}
// get database name
dsList := strings.Split(dataSourceName, "/")
databaseName := dsList[len(dsList)-1]
if databaseName == "" {
if err := createDBIfNotExist(dataSourceName); err != nil {
return nil, err
}
dataSourceName = dataSourceName + "kubernetes"
if err := createDBIfNotExist(parsedDSN); err != nil {
return nil, err
}
db, err := sql.Open("postgres", dataSourceName)
db, err := sql.Open("postgres", parsedDSN)
if err != nil {
return nil, err
}
@ -96,15 +96,35 @@ func Open(dataSourceName string) (*sql.DB, error) {
}
func createDBIfNotExist(dataSourceName string) error {
u, err := url.Parse(dataSourceName)
if err != nil {
return err
}
dbName := strings.SplitN(u.Path, "/", 2)[1]
db, err := sql.Open("postgres", dataSourceName)
if err != nil {
return err
}
_, err = db.Exec(createDB)
err = db.Ping()
// check if database already exists
if err != nil && err.(*pq.Error).Code != "42P04" {
if _, ok := err.(*pq.Error); !ok {
return err
}
if err := err.(*pq.Error); err.Code != "42P04" {
if err.Code != "3D000" {
return err
}
// database doesn't exit, will try to create it
u.Path = "/postgres"
db, err := sql.Open("postgres", u.String())
if err != nil {
return err
}
_, err = db.Exec(createDB + dbName + ";")
if err != nil {
return err
}
}
return nil
}
@ -117,3 +137,46 @@ func q(sql string) string {
return pref + strconv.Itoa(n)
})
}
func prepareDSN(dataSourceName string, tlsInfo *transport.TLSInfo) (string, error) {
if len(dataSourceName) == 0 {
dataSourceName = defaultDSN
} else {
dataSourceName = "postgres://" + dataSourceName
}
u, err := url.Parse(dataSourceName)
if err != nil {
return "", err
}
if len(u.Path) == 0 || u.Path == "/" {
u.Path = "/kubernetes"
}
queryMap, err := url.ParseQuery(u.RawQuery)
if err != nil {
return "", err
}
// set up tls dsn
params := url.Values{}
sslmode := "require"
if _, ok := queryMap["sslcert"]; tlsInfo.CertFile != "" && !ok {
params.Add("sslcert", tlsInfo.CertFile)
sslmode = "verify-full"
}
if _, ok := queryMap["sslkey"]; tlsInfo.KeyFile != "" && !ok {
params.Add("sslkey", tlsInfo.KeyFile)
sslmode = "verify-full"
}
if _, ok := queryMap["sslrootcert"]; tlsInfo.CAFile != "" && !ok {
params.Add("sslrootcert", tlsInfo.CAFile)
sslmode = "verify-full"
}
if _, ok := queryMap["sslmode"]; !ok {
params.Add("sslmode", sslmode)
}
for k, v := range queryMap {
params.Add(k, v[0])
}
u.RawQuery = params.Encode()
return u.String(), nil
}

View File

@ -115,15 +115,17 @@ func newKV(cfg Config) (*kv, error) {
}
driver = sqlite.NewSQLite()
case "mysql":
if db, err = mysql.Open(parts[1], cfg.TLS); err != nil {
if db, err = mysql.Open(parts[1], cfg.TLSInfo); err != nil {
return nil, err
}
driver = mysql.NewMySQL()
case "postgres":
if db, err = pgsql.Open(parts[1]); err != nil {
if db, err = pgsql.Open(parts[1], cfg.TLSInfo); err != nil {
return nil, err
}
driver = pgsql.NewPGSQL()
default:
return nil, fmt.Errorf("unknown driver type [%s]", parts[0])
}
if err := driver.Start(context.TODO(), db); err != nil {

View File

@ -18,7 +18,6 @@ package factory
import (
"context"
"crypto/tls"
"fmt"
"sync/atomic"
"time"
@ -67,22 +66,14 @@ func NewKVSQLHealthCheck(c storagebackend.Config) (func() error, error) {
}
func newETCD3Client(c storagebackend.Config) (*clientv3.Client, error) {
tlsInfo := transport.TLSInfo{
tlsInfo := &transport.TLSInfo{
CertFile: c.Transport.CertFile,
KeyFile: c.Transport.KeyFile,
CAFile: c.Transport.CAFile,
}
tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
return nil, err
}
tlsConfig.MinVersion = tls.VersionTLS11
if len(c.Transport.CertFile) == 0 && len(c.Transport.KeyFile) == 0 && len(c.Transport.CAFile) == 0 {
tlsConfig = nil
}
cfg := clientv3.Config{
Endpoints: c.Transport.ServerList,
TLS: tlsConfig,
TLSInfo: tlsInfo,
}
if len(cfg.Endpoints) == 0 {