LocalAI/pkg/model/initializers.go

303 lines
8.9 KiB
Go
Raw Normal View History

2023-05-11 14:34:16 +00:00
package model
import (
"context"
2023-05-11 14:34:16 +00:00
"fmt"
"os"
"path/filepath"
2023-05-11 14:34:16 +00:00
"strings"
"time"
2023-05-11 14:34:16 +00:00
rwkv "github.com/donomii/go-rwkv.cpp"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
"github.com/go-skynet/LocalAI/pkg/tts"
2023-05-11 14:34:16 +00:00
bloomz "github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
"github.com/hashicorp/go-multierror"
"github.com/hpcloud/tail"
"github.com/phayes/freeport"
2023-05-11 14:34:16 +00:00
"github.com/rs/zerolog/log"
process "github.com/mudler/go-processmanager"
2023-05-11 14:34:16 +00:00
)
// tokenizerSuffix is appended to an RWKV model file name to form the path of
// the tokenizer file that is loaded alongside it.
const tokenizerSuffix = ".tokenizer.json"

// Backend identifiers accepted by BackendLoader. BackendLoader lower-cases the
// requested backend name before comparing it, so every value here is lower case.
const (
	// Backends started as a gRPC subprocess named after the backend.
	LlamaBackend     = "llama"
	StarcoderBackend = "starcoder"
	GPTJBackend      = "gptj"
	DollyBackend     = "dolly"
	MPTBackend       = "mpt"
	GPTNeoXBackend   = "gptneox"
	ReplitBackend    = "replit"
	Gpt2Backend      = "gpt2"
	FalconBackend    = "falcon"

	// gpt4all variants; BackendLoader routes all of them to the Gpt4All gRPC binary.
	Gpt4AllLlamaBackend = "gpt4all-llama"
	Gpt4AllMptBackend   = "gpt4all-mpt"
	Gpt4AllJBackend     = "gpt4all-j"
	Gpt4All             = "gpt4all"

	// Backends loaded by calling a Go constructor directly, without gRPC.
	BloomzBackend          = "bloomz"
	BertEmbeddingsBackend  = "bert-embeddings"
	RwkvBackend            = "rwkv"
	WhisperBackend         = "whisper"
	StableDiffusionBackend = "stablediffusion"
	PiperBackend           = "piper"
	LCHuggingFaceBackend   = "langchain-huggingface"

	// GGLLMFalconBackend = "falcon" stays disabled: it would share its value
	// with FalconBackend.
)
var autoLoadBackends []string = []string{
2023-05-11 14:34:16 +00:00
LlamaBackend,
Gpt4All,
2023-05-11 14:34:16 +00:00
RwkvBackend,
WhisperBackend,
2023-05-11 14:34:16 +00:00
BertEmbeddingsBackend,
GPTNeoXBackend,
GPTJBackend,
Gpt2Backend,
DollyBackend,
MPTBackend,
ReplitBackend,
2023-05-11 18:20:07 +00:00
StarcoderBackend,
FalconBackend,
BloomzBackend,
2023-05-11 18:20:07 +00:00
}
2023-05-11 14:34:16 +00:00
// bertEmbeddings loads the bert.cpp model stored at modelFile.
// On failure it returns an untyped nil model, so callers never receive a
// non-nil interface{} that merely wraps a nil pointer.
var bertEmbeddings = func(modelFile string) (interface{}, error) {
	m, err := bert.New(modelFile)
	if err != nil {
		return nil, err
	}
	return m, nil
}
// bloomzLM loads the bloomz.cpp model stored at modelFile.
// On failure it returns an untyped nil model, so callers never receive a
// non-nil interface{} that merely wraps a nil pointer.
var bloomzLM = func(modelFile string) (interface{}, error) {
	m, err := bloomz.New(modelFile)
	if err != nil {
		return nil, err
	}
	return m, nil
}
// stableDiffusion constructs the stable diffusion backend from assetDir.
// On failure it returns an untyped nil model, so callers never receive a
// non-nil interface{} that merely wraps a nil pointer.
var stableDiffusion = func(assetDir string) (interface{}, error) {
	sd, err := stablediffusion.New(assetDir)
	if err != nil {
		return nil, err
	}
	return sd, nil
}
// piperTTS returns a loader for the piper text-to-speech backend. The path
// passed to the returned loader is ignored; tts.New is always called with
// assetDir. On failure the loader returns an untyped nil model, so callers
// never receive a non-nil interface{} that merely wraps a nil pointer.
func piperTTS(assetDir string) func(s string) (interface{}, error) {
	return func(_ string) (interface{}, error) {
		p, err := tts.New(assetDir)
		if err != nil {
			return nil, err
		}
		return p, nil
	}
}
2023-05-11 14:34:16 +00:00
// whisperModel opens the whisper.cpp model stored at modelFile and returns it
// together with any error reported by the bindings.
var whisperModel = func(modelFile string) (interface{}, error) {
	model, err := whisper.New(modelFile)
	return model, err
}
// lcHuggingFace builds the langchain HuggingFace client for repoId.
// On failure it returns an untyped nil model, so callers never receive a
// non-nil interface{} that merely wraps a nil pointer.
var lcHuggingFace = func(repoId string) (interface{}, error) {
	hf, err := langchain.NewHuggingFace(repoId)
	if err != nil {
		return nil, err
	}
	return hf, nil
}
// func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
// return func(s string) (interface{}, error) {
// return llama.New(s, opts...)
// }
// }
2023-05-11 14:34:16 +00:00
// func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) {
// return func(s string) (interface{}, error) {
// return gpt4all.New(s, opts...)
// }
// }
2023-05-11 14:34:16 +00:00
func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) {
return func(s string) (interface{}, error) {
log.Debug().Msgf("Loading RWKV", s, tokenFile)
2023-05-11 14:34:16 +00:00
model := rwkv.LoadFiles(s, tokenFile, threads)
if model == nil {
return nil, fmt.Errorf("could not load model")
}
return model, nil
}
}
// grpcModel returns a loader that serves a model through a gRPC backend
// subprocess. When called with a model path, the loader:
//
//   - makes <assetDir>/backend-assets/grpc/<backend> executable and starts it
//     with "--addr localhost:<free port>",
//   - records the process in ml.grpcProcesses under o.modelFile,
//   - mirrors the process's stdout and stderr files into the debug log,
//   - polls the service health check up to 10 times, one second apart,
//   - asks the service to load the model, sending a copy of *o.gRPCOptions
//     whose Model field is set to the path.
//
// On success it returns the gRPC client. If the service never becomes ready
// or refuses the model, the subprocess and its log followers are stopped and
// the ml.grpcProcesses entry is removed. A failed attempt (for example one
// step of GreedyLoader) therefore does not leave an orphaned backend process
// running.
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (interface{}, error) {
	return func(s string) (interface{}, error) {
		log.Debug().Msgf("Loading GRPC Model %s: %+v", backend, *o)

		// Checked before anything is spawned: the options are dereferenced
		// when the model is loaded below.
		if o.gRPCOptions == nil {
			return nil, fmt.Errorf("cannot load model with backend %s: no gRPC options set", backend)
		}

		grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)

		// Make sure the process is executable
		if err := os.Chmod(grpcProcess, 0755); err != nil {
			return nil, err
		}
		log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)

		port, err := freeport.GetFreePort()
		if err != nil {
			return nil, err
		}
		serverAddress := fmt.Sprintf("localhost:%d", port)
		log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)

		grpcControlProcess := process.New(
			process.WithTemporaryStateDir(),
			process.WithName(grpcProcess),
			process.WithArgs("--addr", serverAddress))
		ml.grpcProcesses[o.modelFile] = grpcControlProcess

		// unregister drops this attempt's ml.grpcProcesses entry, unless the
		// entry has already been replaced by another process.
		unregister := func() {
			if ml.grpcProcesses[o.modelFile] == grpcControlProcess {
				delete(ml.grpcProcesses, o.modelFile)
			}
		}

		if err := grpcControlProcess.Run(); err != nil {
			unregister()
			return nil, err
		}

		logTag := strings.Join([]string{backend, o.modelFile, serverAddress}, "-")
		logFollowers := []*tail.Tail{
			followProcessLog(grpcControlProcess.StderrPath(), "stderr", logTag),
			followProcessLog(grpcControlProcess.StdoutPath(), "stdout", logTag),
		}

		// abort stops everything this attempt started, then returns cause.
		abort := func(cause error) (interface{}, error) {
			for _, follower := range logFollowers {
				if follower != nil {
					// Best effort. Stopping the tail is expected to close
					// follower.Lines (hpcloud/tail behaviour), which ends the
					// goroutine started by followProcessLog.
					_ = follower.Stop()
				}
			}
			if stopErr := grpcControlProcess.Stop(); stopErr != nil {
				log.Debug().Msgf("Could not stop GRPC process %s: %v", grpcProcess, stopErr)
			}
			unregister()
			return nil, cause
		}

		log.Debug().Msgf("GRPC Service Started")

		client := grpc.NewClient(serverAddress)

		// Wait for the service to start up
		ready := false
		for attempt := 0; attempt < 10; attempt++ {
			if client.HealthCheck(context.Background()) {
				log.Debug().Msgf("GRPC Service Ready")
				ready = true
				break
			}
			time.Sleep(time.Second)
		}
		if !ready {
			exitCode, exitErr := grpcControlProcess.ExitCode()
			log.Debug().Msgf("GRPC Service NOT ready (alive: %v, exit code: %v, exit code read error: %v)",
				grpcControlProcess.IsAlive(), exitCode, exitErr)
			return abort(fmt.Errorf("grpc service not ready"))
		}

		options := *o.gRPCOptions
		options.Model = s
		log.Debug().Msgf("GRPC: Loading model with options: %+v", options)

		res, err := client.LoadModel(context.TODO(), &options)
		if err != nil {
			return abort(err)
		}
		if !res.Success {
			return abort(fmt.Errorf("could not load model: %s", res.Message))
		}
		return client, nil
	}
}

// followProcessLog follows the file at path (one of a gRPC backend process's
// output files) and writes each new line to the debug log, labelled with tag
// and stream. It returns the tail handle so the caller can stop following,
// or nil when the file cannot be tailed.
func followProcessLog(path, stream, tag string) *tail.Tail {
	follower, err := tail.TailFile(path, tail.Config{Follow: true})
	if err != nil {
		// Return early instead of ranging over the Lines channel of a *Tail
		// that TailFile may have returned as nil.
		log.Debug().Msgf("Could not tail %s: %v", stream, err)
		return nil
	}
	go func() {
		for line := range follower.Lines {
			log.Debug().Msgf("GRPC(%s): %s %s", tag, stream, line.Text)
		}
	}()
	return follower
}
func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err error) {
//backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) {
o := NewOptions(opts...)
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
switch strings.ToLower(o.backendString) {
2023-05-11 14:34:16 +00:00
case LlamaBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o))
2023-05-11 14:34:16 +00:00
case BloomzBackend:
return ml.LoadModel(o.modelFile, bloomzLM)
case GPTJBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(GPTJBackend, o))
2023-05-11 14:34:16 +00:00
case DollyBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(DollyBackend, o))
case MPTBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(MPTBackend, o))
2023-05-11 14:34:16 +00:00
case Gpt2Backend:
return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt2Backend, o))
case FalconBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(FalconBackend, o))
2023-05-12 09:36:35 +00:00
case GPTNeoXBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(GPTNeoXBackend, o))
2023-05-12 09:36:35 +00:00
case ReplitBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(ReplitBackend, o))
case StableDiffusionBackend:
return ml.LoadModel(o.modelFile, stableDiffusion)
case PiperBackend:
return ml.LoadModel(o.modelFile, piperTTS(filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data")))
2023-05-11 18:20:07 +00:00
case StarcoderBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(StarcoderBackend, o))
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all")
return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o))
// return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all"))))
2023-05-11 14:34:16 +00:00
case BertEmbeddingsBackend:
return ml.LoadModel(o.modelFile, bertEmbeddings)
2023-05-11 14:34:16 +00:00
case RwkvBackend:
return ml.LoadModel(o.modelFile, rwkvLM(filepath.Join(ml.ModelPath, o.modelFile+tokenizerSuffix), o.threads))
2023-05-11 14:34:16 +00:00
case WhisperBackend:
return ml.LoadModel(o.modelFile, whisperModel)
case LCHuggingFaceBackend:
return ml.LoadModel(o.modelFile, lcHuggingFace)
2023-05-11 14:34:16 +00:00
default:
return nil, fmt.Errorf("backend unsupported: %s", o.backendString)
2023-05-11 14:34:16 +00:00
}
}
func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) {
o := NewOptions(opts...)
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
2023-05-11 14:34:16 +00:00
ml.mu.Lock()
m, exists := ml.models[o.modelFile]
2023-05-11 14:34:16 +00:00
if exists {
log.Debug().Msgf("Model '%s' already loaded", o.modelFile)
2023-05-11 14:34:16 +00:00
ml.mu.Unlock()
return m, nil
}
ml.mu.Unlock()
var err error
for _, b := range autoLoadBackends {
2023-05-11 14:34:16 +00:00
if b == BloomzBackend || b == WhisperBackend || b == RwkvBackend { // do not autoload bloomz/whisper/rwkv
continue
}
log.Debug().Msgf("[%s] Attempting to load", b)
model, modelerr := ml.BackendLoader(
WithBackendString(b),
WithModelFile(o.modelFile),
WithLoadGRPCOpts(o.gRPCOptions),
WithThreads(o.threads),
WithAssetDir(o.assetDir),
)
2023-05-11 14:34:16 +00:00
if modelerr == nil && model != nil {
log.Debug().Msgf("[%s] Loads OK", b)
return model, nil
} else if modelerr != nil {
err = multierror.Append(err, modelerr)
log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error())
}
}
return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
}