k3s/pkg/agent/loadbalancer/loadbalancer_test.go
2019-07-30 09:53:15 -07:00

184 lines
3.2 KiB
Go

package loadbalancer
import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/url"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/rancher/k3s/pkg/cli/cmds"
)
// server is a minimal TCP server on 127.0.0.1 used as a fake backend in the
// load balancer tests. Every newline-terminated line it receives is written
// back prefixed with "<prefix>:".
type server struct {
	listener net.Listener
	prefix   string

	// mu guards conns and closed: serve appends from its own goroutine while
	// close iterates from the test goroutine.
	mu     sync.Mutex
	conns  []net.Conn
	closed bool
}

// createServer starts a server listening on an ephemeral loopback port and
// begins accepting connections in the background.
func createServer(prefix string) (*server, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, err
	}
	s := &server{
		prefix:   prefix,
		listener: listener,
	}
	go s.serve()
	return s, nil
}

// serve accepts connections until the listener is closed, recording each one
// so close can tear it down, and echoes on it in a separate goroutine.
func (s *server) serve() {
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			return
		}
		s.mu.Lock()
		if s.closed {
			// Accepted concurrently with close: close has already swept
			// conns, so shut this one down here instead of leaking it.
			s.mu.Unlock()
			conn.Close()
			return
		}
		s.conns = append(s.conns, conn)
		s.mu.Unlock()
		go s.echo(conn)
	}
}

// close stops accepting new connections and closes every accepted one.
// Calling it more than once is harmless.
func (s *server) close() {
	s.listener.Close()
	s.mu.Lock()
	defer s.mu.Unlock()
	s.closed = true
	for _, conn := range s.conns {
		conn.Close()
	}
}

// echo writes each line read from conn back to it with the server's prefix,
// until a read or write fails.
func (s *server) echo(conn net.Conn) {
	// One reader for the life of the connection: a per-line reader would
	// discard whatever it had buffered past the first newline.
	reader := bufio.NewReader(conn)
	for {
		result, err := reader.ReadString('\n')
		if err != nil {
			return
		}
		if _, err := conn.Write([]byte(s.prefix + ":" + result)); err != nil {
			return
		}
	}
}
// ping writes "ping\n" to conn and returns the next line read back, with
// surrounding whitespace trimmed. It returns the write error if the request
// cannot be sent, otherwise any error from reading the reply.
//
// A fresh bufio.Reader is used on every call, so any bytes it buffers beyond
// the first newline are discarded; callers should expect exactly one reply
// line per ping.
func ping(conn net.Conn) (string, error) {
	if _, err := fmt.Fprintf(conn, "ping\n"); err != nil {
		return "", err
	}
	result, err := bufio.NewReader(conn).ReadString('\n')
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(result), nil
}
// assertEqual fails the test immediately unless a == b under interface
// comparison (same dynamic type and equal value). Comparing two values of the
// same uncomparable dynamic type (e.g. slices or maps) panics.
func assertEqual(t *testing.T, a interface{}, b interface{}) {
	// Report failures at the caller's line rather than this helper's.
	t.Helper()
	if a != b {
		t.Fatalf("[ %v != %v ]", a, b)
	}
}
// assertNotEqual fails the test immediately if a == b under interface
// comparison (same dynamic type and equal value). Comparing two values of the
// same uncomparable dynamic type (e.g. slices or maps) panics.
func assertNotEqual(t *testing.T, a interface{}, b interface{}) {
	// Report failures at the caller's line rather than this helper's.
	t.Helper()
	if a == b {
		t.Fatalf("[ %v == %v ]", a, b)
	}
}
func TestFailOver(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "lb-test")
if err != nil {
assertEqual(t, err, nil)
}
defer os.RemoveAll(tmpDir)
ogServe, err := createServer("og")
if err != nil {
assertEqual(t, err, nil)
}
lbServe, err := createServer("lb")
if err != nil {
assertEqual(t, err, nil)
}
cfg := cmds.Agent{
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()),
DataDir: tmpDir,
}
lb, err := Setup(context.Background(), cfg)
if err != nil {
assertEqual(t, err, nil)
}
parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
if err != nil {
assertEqual(t, err, nil)
}
localAddress := parsedURL.Host
lb.Update([]string{lbServe.listener.Addr().String()})
conn1, err := net.Dial("tcp", localAddress)
if err != nil {
assertEqual(t, err, nil)
}
result1, err := ping(conn1)
if err != nil {
assertEqual(t, err, nil)
}
assertEqual(t, result1, "lb:ping")
lbServe.close()
_, err = ping(conn1)
assertNotEqual(t, err, nil)
conn2, err := net.Dial("tcp", localAddress)
if err != nil {
assertEqual(t, err, nil)
}
result2, err := ping(conn2)
if err != nil {
assertEqual(t, err, nil)
}
assertEqual(t, result2, "og:ping")
}
func TestFailFast(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "lb-test")
if err != nil {
assertEqual(t, err, nil)
}
defer os.RemoveAll(tmpDir)
cfg := cmds.Agent{
ServerURL: "http://127.0.0.1:-1/",
DataDir: tmpDir,
}
lb, err := Setup(context.Background(), cfg)
if err != nil {
assertEqual(t, err, nil)
}
conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
assertEqual(t, err, nil)
}
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
select {
case err := <-done:
assertNotEqual(t, err, nil)
case <-timeout:
t.Fatal(errors.New("time out"))
}
}