mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
Keep whisper models in memory (#233)
This commit is contained in:
parent
6b5e2b2bf5
commit
032dee256f
@ -436,7 +436,12 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
tr, err := whisper.Transcript(filepath.Join(loader.ModelPath, config.Model), dst, input.Language)
|
||||
whisperModel, err := loader.WhisperLoader("whisper", config.Model)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
||||
tr, err := whisper.Transcript(whisperModel, dst, input.Language)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
rwkv "github.com/donomii/go-rwkv.cpp"
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
bloomz "github.com/go-skynet/bloomz.cpp"
|
||||
bert "github.com/go-skynet/go-bert.cpp"
|
||||
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
||||
@ -32,9 +33,9 @@ type ModelLoader struct {
|
||||
redpajama map[string]*gpt2.RedPajama
|
||||
rwkv map[string]*rwkv.RwkvState
|
||||
bloomz map[string]*bloomz.Bloomz
|
||||
|
||||
bert map[string]*bert.Bert
|
||||
promptsTemplates map[string]*template.Template
|
||||
bert map[string]*bert.Bert
|
||||
promptsTemplates map[string]*template.Template
|
||||
whisperModels map[string]whisper.Model
|
||||
}
|
||||
|
||||
func NewModelLoader(modelPath string) *ModelLoader {
|
||||
@ -50,6 +51,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
|
||||
bloomz: make(map[string]*bloomz.Bloomz),
|
||||
bert: make(map[string]*bert.Bert),
|
||||
promptsTemplates: make(map[string]*template.Template),
|
||||
whisperModels: make(map[string]whisper.Model),
|
||||
}
|
||||
}
|
||||
|
||||
@ -422,6 +424,33 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
|
||||
return model, err
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) LoadWhisperModel(modelName string) (whisper.Model, error) {
|
||||
ml.mu.Lock()
|
||||
defer ml.mu.Unlock()
|
||||
|
||||
// Check if we already have a loaded model
|
||||
if !ml.ExistsInModelPath(modelName) {
|
||||
return nil, fmt.Errorf("model does not exist -- %s", modelName)
|
||||
}
|
||||
|
||||
if m, ok := ml.whisperModels[modelName]; ok {
|
||||
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Load the model and keep it in memory for later use
|
||||
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
|
||||
|
||||
model, err := whisper.New(modelFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ml.whisperModels[modelName] = model
|
||||
return model, err
|
||||
}
|
||||
|
||||
// tokenizerSuffix is the ".tokenizer.json" file-name suffix.
// NOTE(review): presumably appended to a model file name to locate its
// tokenizer file — confirm against the callers, which are not visible here.
const tokenizerSuffix = ".tokenizer.json"
|
||||
|
||||
// loadedModels is a package-level map from a string key to a value of any
// type, initialised empty.
// NOTE(review): presumably keyed by model file name and holding models that
// were already loaded — confirm against the code that reads and writes it.
var loadedModels map[string]interface{} = map[string]interface{}{}
|
||||
@ -452,6 +481,16 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) WhisperLoader(backendString string, modelFile string) (model whisper.Model, err error) {
|
||||
//TODO expose more whisper options in next PR
|
||||
switch strings.ToLower(backendString) {
|
||||
case "whisper":
|
||||
return ml.LoadWhisperModel(modelFile)
|
||||
default:
|
||||
return nil, fmt.Errorf("whisper backend unsupported: %s", backendString)
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
|
||||
updateModels := func(model interface{}) {
|
||||
muModels.Lock()
|
||||
|
@ -28,7 +28,7 @@ func audioToWav(src, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Transcript(modelpath, audiopath, language string) (string, error) {
|
||||
func Transcript(model whisper.Model, audiopath, language string) (string, error) {
|
||||
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
if err != nil {
|
||||
@ -58,13 +58,6 @@ func Transcript(modelpath, audiopath, language string) (string, error) {
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
|
||||
// Load the model
|
||||
model, err := whisper.New(modelpath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Process samples
|
||||
context, err := model.NewContext()
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user