mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
fdb45153fe
* feat(llama.cpp): Enable decentralized, distributed inference As https://github.com/mudler/LocalAI/pull/2324 introduced distributed inferencing thanks to @rgerganov implementation in https://github.com/ggerganov/llama.cpp/pull/6829 in upstream llama.cpp, now it is possible to distribute the workload to remote llama.cpp gRPC server. This changeset now uses mudler/edgevpn to establish a secure, distributed network between the nodes using a shared token. The token is generated automatically when starting the server with the `--p2p` flag, and can be used by starting the workers with `local-ai worker p2p-llama-cpp-rpc` by passing the token via environment variable (TOKEN) or with args (--token). As per how mudler/edgevpn works, a network is established between the server and the workers with dht and mdns discovery protocols, the llama.cpp rpc server is automatically started and exposed to the underlying p2p network so the API server can connect on. When the HTTP server is started, it will discover the workers in the network and automatically create the port-forwards to the service locally. Then llama.cpp is configured to use the services. This feature is behind the "p2p" GO_FLAGS Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * go mod tidy Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * ci: add p2p tag Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * better message Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
315 lines
7.7 KiB
Go
315 lines
7.7 KiB
Go
//go:build p2p
|
|
// +build p2p
|
|
|
|
package p2p
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
"github.com/mudler/edgevpn/pkg/node"
|
|
"github.com/mudler/edgevpn/pkg/protocol"
|
|
"github.com/mudler/edgevpn/pkg/types"
|
|
"github.com/phayes/freeport"
|
|
|
|
"github.com/ipfs/go-log"
|
|
"github.com/mudler/edgevpn/pkg/config"
|
|
"github.com/mudler/edgevpn/pkg/services"
|
|
zlog "github.com/rs/zerolog/log"
|
|
|
|
"github.com/mudler/edgevpn/pkg/logger"
|
|
)
|
|
|
|
func GenerateToken() string {
|
|
// Generates a new config and exit
|
|
newData := node.GenerateNewConnectionData(900)
|
|
return newData.Base64()
|
|
}
|
|
|
|
func allocateLocalService(ctx context.Context, node *node.Node, listenAddr, service string) error {
|
|
|
|
zlog.Info().Msgf("Allocating service '%s' on: %s", service, listenAddr)
|
|
// Open local port for listening
|
|
l, err := net.Listen("tcp", listenAddr)
|
|
if err != nil {
|
|
zlog.Error().Err(err).Msg("Error listening")
|
|
return err
|
|
}
|
|
// ll.Info("Binding local port on", srcaddr)
|
|
|
|
ledger, _ := node.Ledger()
|
|
|
|
// Announce ourselves so nodes accepts our connection
|
|
ledger.Announce(
|
|
ctx,
|
|
10*time.Second,
|
|
func() {
|
|
// Retrieve current ID for ip in the blockchain
|
|
_, found := ledger.GetKey(protocol.UsersLedgerKey, node.Host().ID().String())
|
|
// If mismatch, update the blockchain
|
|
if !found {
|
|
updatedMap := map[string]interface{}{}
|
|
updatedMap[node.Host().ID().String()] = &types.User{
|
|
PeerID: node.Host().ID().String(),
|
|
Timestamp: time.Now().String(),
|
|
}
|
|
ledger.Add(protocol.UsersLedgerKey, updatedMap)
|
|
}
|
|
},
|
|
)
|
|
|
|
defer l.Close()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return errors.New("context canceled")
|
|
default:
|
|
zlog.Debug().Msg("New for connection")
|
|
// Listen for an incoming connection.
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
fmt.Println("Error accepting: ", err.Error())
|
|
continue
|
|
}
|
|
|
|
// ll.Info("New connection from", l.Addr().String())
|
|
// Handle connections in a new goroutine, forwarding to the p2p service
|
|
go func() {
|
|
// Retrieve current ID for ip in the blockchain
|
|
existingValue, found := ledger.GetKey(protocol.ServicesLedgerKey, service)
|
|
service := &types.Service{}
|
|
existingValue.Unmarshal(service)
|
|
// If mismatch, update the blockchain
|
|
if !found {
|
|
zlog.Error().Msg("Service not found on blockchain")
|
|
conn.Close()
|
|
// ll.Debugf("service '%s' not found on blockchain", serviceID)
|
|
return
|
|
}
|
|
|
|
// Decode the Peer
|
|
d, err := peer.Decode(service.PeerID)
|
|
if err != nil {
|
|
zlog.Error().Msg("cannot decode peer")
|
|
|
|
conn.Close()
|
|
// ll.Debugf("could not decode peer '%s'", service.PeerID)
|
|
return
|
|
}
|
|
|
|
// Open a stream
|
|
stream, err := node.Host().NewStream(ctx, d, protocol.ServiceProtocol.ID())
|
|
if err != nil {
|
|
zlog.Error().Msg("cannot open stream peer")
|
|
|
|
conn.Close()
|
|
// ll.Debugf("could not open stream '%s'", err.Error())
|
|
return
|
|
}
|
|
// ll.Debugf("(service %s) Redirecting", serviceID, l.Addr().String())
|
|
zlog.Info().Msgf("Redirecting %s to %s", conn.LocalAddr().String(), stream.Conn().RemoteMultiaddr().String())
|
|
closer := make(chan struct{}, 2)
|
|
go copyStream(closer, stream, conn)
|
|
go copyStream(closer, conn, stream)
|
|
<-closer
|
|
|
|
stream.Close()
|
|
conn.Close()
|
|
// ll.Infof("(service %s) Done handling %s", serviceID, l.Addr().String())
|
|
}()
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func copyStream(closer chan struct{}, dst io.Writer, src io.Reader) {
|
|
defer func() { closer <- struct{}{} }() // connection is closed, send signal to stop proxy
|
|
io.Copy(dst, src)
|
|
}
|
|
|
|
// This is the main of the server (which keeps the env variable updated)
|
|
// This starts a goroutine that keeps LLAMACPP_GRPC_SERVERS updated with the discovered services
|
|
func LLamaCPPRPCServerDiscoverer(ctx context.Context, token string) error {
|
|
tunnels, err := discoveryTunnels(ctx, token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
totalTunnels := []string{}
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
zlog.Error().Msg("Discoverer stopped")
|
|
return
|
|
case tunnel := <-tunnels:
|
|
|
|
totalTunnels = append(totalTunnels, tunnel)
|
|
os.Setenv("LLAMACPP_GRPC_SERVERS", strings.Join(totalTunnels, ","))
|
|
zlog.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", strings.Join(totalTunnels, ","))
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func discoveryTunnels(ctx context.Context, token string) (chan string, error) {
|
|
tunnels := make(chan string)
|
|
|
|
nodeOpts, err := newNodeOpts(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
n, err := node.New(nodeOpts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating a new node: %w", err)
|
|
}
|
|
err = n.Start(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating a new node: %w", err)
|
|
}
|
|
ledger, err := n.Ledger()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating a new node: %w", err)
|
|
}
|
|
|
|
// get new services, allocate and return to the channel
|
|
go func() {
|
|
emitted := map[string]bool{}
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
zlog.Error().Msg("Discoverer stopped")
|
|
return
|
|
default:
|
|
time.Sleep(5 * time.Second)
|
|
zlog.Debug().Msg("Searching for workers")
|
|
|
|
data := ledger.LastBlock().Storage["services_localai"]
|
|
for k := range data {
|
|
zlog.Info().Msgf("Found worker %s", k)
|
|
if _, found := emitted[k]; !found {
|
|
emitted[k] = true
|
|
//discoveredPeers <- k
|
|
port, err := freeport.GetFreePort()
|
|
if err != nil {
|
|
fmt.Print(err)
|
|
}
|
|
tunnelAddress := fmt.Sprintf("127.0.0.1:%d", port)
|
|
go allocateLocalService(ctx, n, tunnelAddress, k)
|
|
tunnels <- tunnelAddress
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return tunnels, err
|
|
}
|
|
|
|
// This is the P2P worker main
|
|
func BindLLamaCPPWorker(ctx context.Context, host, port, token string) error {
|
|
llger := logger.New(log.LevelFatal)
|
|
|
|
nodeOpts, err := newNodeOpts(token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// generate a random string for the name
|
|
name := utils.RandString(10)
|
|
|
|
// Register the service
|
|
nodeOpts = append(nodeOpts,
|
|
services.RegisterService(llger, time.Duration(60)*time.Second, name, fmt.Sprintf("%s:%s", host, port))...)
|
|
n, err := node.New(nodeOpts...)
|
|
if err != nil {
|
|
return fmt.Errorf("creating a new node: %w", err)
|
|
}
|
|
|
|
err = n.Start(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("creating a new node: %w", err)
|
|
}
|
|
|
|
ledger, err := n.Ledger()
|
|
if err != nil {
|
|
return fmt.Errorf("creating a new node: %w", err)
|
|
}
|
|
|
|
ledger.Announce(
|
|
ctx,
|
|
10*time.Second,
|
|
func() {
|
|
// Retrieve current ID for ip in the blockchain
|
|
_, found := ledger.GetKey("services_localai", name)
|
|
// If mismatch, update the blockchain
|
|
if !found {
|
|
updatedMap := map[string]interface{}{}
|
|
updatedMap[name] = "p2p"
|
|
ledger.Add("services_localai", updatedMap)
|
|
}
|
|
},
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func newNodeOpts(token string) ([]node.Option, error) {
|
|
llger := logger.New(log.LevelFatal)
|
|
defaultInterval := 10 * time.Second
|
|
|
|
loglevel := "info"
|
|
|
|
c := config.Config{
|
|
Limit: config.ResourceLimit{
|
|
Enable: true,
|
|
MaxConns: 100,
|
|
},
|
|
NetworkToken: token,
|
|
LowProfile: false,
|
|
LogLevel: loglevel,
|
|
Libp2pLogLevel: "fatal",
|
|
Ledger: config.Ledger{
|
|
SyncInterval: defaultInterval,
|
|
AnnounceInterval: defaultInterval,
|
|
},
|
|
NAT: config.NAT{
|
|
Service: true,
|
|
Map: true,
|
|
RateLimit: true,
|
|
RateLimitGlobal: 10,
|
|
RateLimitPeer: 10,
|
|
RateLimitInterval: defaultInterval,
|
|
},
|
|
Discovery: config.Discovery{
|
|
DHT: true,
|
|
MDNS: true,
|
|
Interval: 30 * time.Second,
|
|
},
|
|
Connection: config.Connection{
|
|
HolePunch: true,
|
|
AutoRelay: true,
|
|
MaxConnections: 100,
|
|
},
|
|
}
|
|
|
|
nodeOpts, _, err := c.ToOpts(llger)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing options: %w", err)
|
|
}
|
|
|
|
nodeOpts = append(nodeOpts, services.Alive(30*time.Second, 900*time.Second, 15*time.Minute)...)
|
|
|
|
return nodeOpts, nil
|
|
}
|