mirror of
https://github.com/k3s-io/k3s.git
synced 2024-06-07 19:41:36 +00:00
137 lines
3.9 KiB
Go
137 lines
3.9 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/canonical/go-dqlite/internal/protocol"
|
|
_ "github.com/mattn/go-sqlite3" // Go SQLite bindings
|
|
)
|
|
|
|
// NodeStore is used by a dqlite client to get an initial list of candidate
|
|
// dqlite nodes that it can dial in order to find a leader dqlite node to use.
|
|
type NodeStore = protocol.NodeStore
|
|
|
|
// NodeInfo holds information about a single server.
|
|
type NodeInfo = protocol.NodeInfo
|
|
|
|
// InmemNodeStore keeps the list of target dqlite nodes in memory.
|
|
type InmemNodeStore = protocol.InmemNodeStore
|
|
|
|
// NewInmemNodeStore creates NodeStore which stores its data in-memory.
|
|
var NewInmemNodeStore = protocol.NewInmemNodeStore
|
|
|
|
// DatabaseNodeStore persists a list of dqlite node addresses in a SQL table.
type DatabaseNodeStore struct {
	// db is the database handle to use.
	db *sql.DB
	// schema is the name of the schema holding the servers table.
	schema string
	// table is the name of the servers table.
	table string
	// column is the name of the column in the servers table holding the
	// server address.
	column string
}
|
|
|
|
// DefaultNodeStore creates a new NodeStore using the given filename to
|
|
// open a SQLite database, with default names for the schema, table and column
|
|
// parameters.
|
|
//
|
|
// It also creates the table if it doesn't exist yet.
|
|
func DefaultNodeStore(filename string) (*DatabaseNodeStore, error) {
|
|
// Open the database.
|
|
db, err := sql.Open("sqlite3", filename)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to open database")
|
|
}
|
|
|
|
// Since we're setting SQLite single-thread mode, we need to have one
|
|
// connection at most.
|
|
db.SetMaxOpenConns(1)
|
|
|
|
// Create the servers table if it does not exist yet.
|
|
_, err = db.Exec("CREATE TABLE IF NOT EXISTS servers (address TEXT, UNIQUE(address))")
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to create servers table")
|
|
}
|
|
|
|
store := NewNodeStore(db, "main", "servers", "address")
|
|
|
|
return store, nil
|
|
}
|
|
|
|
// NewNodeStore creates a new NodeStore.
|
|
func NewNodeStore(db *sql.DB, schema, table, column string) *DatabaseNodeStore {
|
|
return &DatabaseNodeStore{
|
|
db: db,
|
|
schema: schema,
|
|
table: table,
|
|
column: column,
|
|
}
|
|
}
|
|
|
|
// Get the current servers.
|
|
func (d *DatabaseNodeStore) Get(ctx context.Context) ([]NodeInfo, error) {
|
|
tx, err := d.db.Begin()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to begin transaction")
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
query := fmt.Sprintf("SELECT %s FROM %s.%s", d.column, d.schema, d.table)
|
|
rows, err := tx.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to query servers table")
|
|
}
|
|
defer rows.Close()
|
|
|
|
servers := make([]NodeInfo, 0)
|
|
for rows.Next() {
|
|
var address string
|
|
err := rows.Scan(&address)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to fetch server address")
|
|
}
|
|
servers = append(servers, NodeInfo{ID: 1, Address: address})
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, errors.Wrap(err, "result set failure")
|
|
}
|
|
|
|
return servers, nil
|
|
}
|
|
|
|
// Set the servers addresses.
|
|
func (d *DatabaseNodeStore) Set(ctx context.Context, servers []NodeInfo) error {
|
|
tx, err := d.db.Begin()
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed to begin transaction")
|
|
}
|
|
|
|
query := fmt.Sprintf("DELETE FROM %s.%s", d.schema, d.table)
|
|
if _, err := tx.ExecContext(ctx, query); err != nil {
|
|
tx.Rollback()
|
|
return errors.Wrap(err, "failed to delete existing servers rows")
|
|
}
|
|
|
|
query = fmt.Sprintf("INSERT INTO %s.%s(%s) VALUES (?)", d.schema, d.table, d.column)
|
|
stmt, err := tx.PrepareContext(ctx, query)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return errors.Wrap(err, "failed to prepare insert statement")
|
|
}
|
|
defer stmt.Close()
|
|
|
|
for _, server := range servers {
|
|
if _, err := stmt.ExecContext(ctx, server.Address); err != nil {
|
|
tx.Rollback()
|
|
return errors.Wrapf(err, "failed to insert server %s", server.Address)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return errors.Wrap(err, "failed to commit transaction")
|
|
}
|
|
|
|
return nil
|
|
}
|