Keep whisper models in memory (#233)

Matthew Campbell 2023-05-11 19:05:07 +07:00 committed by GitHub
parent 6b5e2b2bf5
commit 032dee256f
3 changed files with 49 additions and 12 deletions


@@ -436,7 +436,12 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 	log.Debug().Msgf("Audio file copied to: %+v", dst)

-	tr, err := whisper.Transcript(filepath.Join(loader.ModelPath, config.Model), dst, input.Language)
+	whisperModel, err := loader.WhisperLoader("whisper", config.Model)
+	if err != nil {
+		return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
+	}
+
+	tr, err := whisper.Transcript(whisperModel, dst, input.Language)
 	if err != nil {
 		return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
 	}
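Only the first transcription request for a given model now pays the load cost; later requests reuse the cached handle. A rough way to observe this from the client side, as a sketch only: it assumes a build of this branch listening on localhost:8080, that the route is mounted at /v1/audio/transcriptions, and that the form fields follow the OpenAI audio API ("file", "model"); the model filename is illustrative.

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"time"
)

// transcribe posts one audio file and reports how long the request took.
func transcribe(path, modelName string) (time.Duration, error) {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)
	f, err := os.Open(path)
	if err != nil {
		return 0, err
	}
	defer f.Close()
	part, err := w.CreateFormFile("file", path)
	if err != nil {
		return 0, err
	}
	if _, err := io.Copy(part, f); err != nil {
		return 0, err
	}
	w.WriteField("model", modelName)
	w.Close()

	start := time.Now()
	resp, err := http.Post("http://localhost:8080/v1/audio/transcriptions",
		w.FormDataContentType(), &body)
	if err != nil {
		return 0, err
	}
	resp.Body.Close()
	return time.Since(start), nil
}

func main() {
	// The second call should be noticeably faster: the whisper model is
	// already resident in the loader's cache.
	for i := 0; i < 2; i++ {
		d, err := transcribe("audio.wav", "ggml-base.en.bin")
		fmt.Println("request", i, "took", d, "err:", err)
	}
}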


@@ -11,6 +11,7 @@ import (
 	"text/template"

 	rwkv "github.com/donomii/go-rwkv.cpp"
+	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
 	bloomz "github.com/go-skynet/bloomz.cpp"
 	bert "github.com/go-skynet/go-bert.cpp"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
@@ -32,9 +33,9 @@ type ModelLoader struct {
 	redpajama        map[string]*gpt2.RedPajama
 	rwkv             map[string]*rwkv.RwkvState
 	bloomz           map[string]*bloomz.Bloomz
 	bert             map[string]*bert.Bert
-
 	promptsTemplates map[string]*template.Template
+	whisperModels    map[string]whisper.Model
 }
@@ -50,6 +51,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
 		bloomz:           make(map[string]*bloomz.Bloomz),
 		bert:             make(map[string]*bert.Bert),
 		promptsTemplates: make(map[string]*template.Template),
+		whisperModels:    make(map[string]whisper.Model),
 	}
 }
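The make call in the constructor is load-bearing: Go maps are nil until initialized, and writing to a nil map panics at runtime. A two-line illustration of the failure mode the initializer guards against:

package main

func main() {
	var m map[string]int // nil map: lookups are fine...
	_ = m["x"]
	m["x"] = 1 // ...but this write panics: "assignment to entry in nil map"
}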
@@ -422,6 +424,33 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
 	return model, err
 }

+func (ml *ModelLoader) LoadWhisperModel(modelName string) (whisper.Model, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	// Check if we already have a loaded model
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist -- %s", modelName)
+	}
+
+	if m, ok := ml.whisperModels[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.ModelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model, err := whisper.New(modelFile)
+	if err != nil {
+		return nil, err
+	}
+
+	ml.whisperModels[modelName] = model
+	return model, err
+}
+
 const tokenizerSuffix = ".tokenizer.json"

 var loadedModels map[string]interface{} = map[string]interface{}{}
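LoadWhisperModel is a mutex-guarded lazy cache: take the lock, return a hit from the map, otherwise load once and store. The same shape reduced to a standalone sketch, with illustrative names and a string standing in for whisper.Model:

package main

import (
	"fmt"
	"sync"
)

// modelCache memoizes expensive loads by name, mirroring the
// mutex-guarded map pattern in LoadWhisperModel above.
type modelCache struct {
	mu     sync.Mutex
	models map[string]string // stand-in for map[string]whisper.Model
}

func (c *modelCache) load(name string) string {
	c.mu.Lock()
	defer c.mu.Unlock()
	if m, ok := c.models[name]; ok {
		return m // cache hit: no reload
	}
	m := "loaded:" + name // stand-in for whisper.New(modelFile)
	c.models[name] = m
	return m
}

func main() {
	c := &modelCache{models: make(map[string]string)}
	fmt.Println(c.load("ggml-base.en.bin")) // loads
	fmt.Println(c.load("ggml-base.en.bin")) // cache hit, same value back
}

Holding the lock across the load means two first requests for different models serialize; that is a simplicity trade-off rather than a correctness issue.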
@@ -452,6 +481,16 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
 	}
 }

+func (ml *ModelLoader) WhisperLoader(backendString string, modelFile string) (model whisper.Model, err error) {
+	//TODO expose more whisper options in next PR
+	switch strings.ToLower(backendString) {
+	case "whisper":
+		return ml.LoadWhisperModel(modelFile)
+	default:
+		return nil, fmt.Errorf("whisper backend unsupported: %s", backendString)
+	}
+}
+
 func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	updateModels := func(model interface{}) {
 		muModels.Lock()
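Keeping whisper out of BackendLoader's interface{} return value means callers get a whisper.Model back with no type assertion. A minimal caller as a sketch, with the LocalAI package import path assumed and an illustrative model filename:

package main

import (
	"fmt"

	model "github.com/go-skynet/LocalAI/pkg/model" // assumed import path
)

func main() {
	loader := model.NewModelLoader("/models")

	// "ggml-base.en.bin" is illustrative; the file must exist under the
	// model path or WhisperLoader returns the "model does not exist" error.
	m, err := loader.WhisperLoader("whisper", "ggml-base.en.bin")
	if err != nil {
		fmt.Println("load failed:", err)
		return
	}
	fmt.Printf("whisper model resident: %T\n", m)
}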


@@ -28,7 +28,7 @@ func audioToWav(src, dst string) error {
 	return nil
 }

-func Transcript(modelpath, audiopath, language string) (string, error) {
+func Transcript(model whisper.Model, audiopath, language string) (string, error) {
 	dir, err := os.MkdirTemp("", "whisper")
 	if err != nil {
@@ -58,13 +58,6 @@ func Transcript(modelpath, audiopath, language string) (string, error) {
 	data := buf.AsFloat32Buffer().Data

-	// Load the model
-	model, err := whisper.New(modelpath)
-	if err != nil {
-		return "", err
-	}
-	defer model.Close()
-
 	// Process samples
 	context, err := model.NewContext()
 	if err != nil {
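With whisper.New and the deferred model.Close gone from Transcript, the loader's cache owns the handle: nothing in this diff ever closes a cached model, so it stays resident until the process exits, which is the point of the change. The new end-to-end call shape as a sketch, with both LocalAI package import paths assumed and an illustrative model filename:

package main

import (
	"fmt"

	model "github.com/go-skynet/LocalAI/pkg/model"         // assumed path
	whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" // assumed path
)

func main() {
	loader := model.NewModelLoader("/models")
	m, err := loader.WhisperLoader("whisper", "ggml-base.en.bin")
	if err != nil {
		panic(err)
	}

	// One handle serves every request; Transcript neither opens nor
	// closes the model anymore, so m stays resident for the process.
	for _, f := range []string{"first.wav", "second.wav"} {
		text, err := whisperutil.Transcript(m, f, "en")
		fmt.Println(f, text, err)
	}
}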