mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
Keep whisper models in memory (#233)
This commit is contained in:
parent
6b5e2b2bf5
commit
032dee256f
@ -436,7 +436,12 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
|
|||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
tr, err := whisper.Transcript(filepath.Join(loader.ModelPath, config.Model), dst, input.Language)
|
whisperModel, err := loader.WhisperLoader("whisper", config.Model)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
tr, err := whisper.Transcript(whisperModel, dst, input.Language)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
rwkv "github.com/donomii/go-rwkv.cpp"
|
rwkv "github.com/donomii/go-rwkv.cpp"
|
||||||
|
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
bloomz "github.com/go-skynet/bloomz.cpp"
|
bloomz "github.com/go-skynet/bloomz.cpp"
|
||||||
bert "github.com/go-skynet/go-bert.cpp"
|
bert "github.com/go-skynet/go-bert.cpp"
|
||||||
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
gpt2 "github.com/go-skynet/go-gpt2.cpp"
|
||||||
@ -32,9 +33,9 @@ type ModelLoader struct {
|
|||||||
redpajama map[string]*gpt2.RedPajama
|
redpajama map[string]*gpt2.RedPajama
|
||||||
rwkv map[string]*rwkv.RwkvState
|
rwkv map[string]*rwkv.RwkvState
|
||||||
bloomz map[string]*bloomz.Bloomz
|
bloomz map[string]*bloomz.Bloomz
|
||||||
|
bert map[string]*bert.Bert
|
||||||
bert map[string]*bert.Bert
|
promptsTemplates map[string]*template.Template
|
||||||
promptsTemplates map[string]*template.Template
|
whisperModels map[string]whisper.Model
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewModelLoader(modelPath string) *ModelLoader {
|
func NewModelLoader(modelPath string) *ModelLoader {
|
||||||
@ -50,6 +51,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
|
|||||||
bloomz: make(map[string]*bloomz.Bloomz),
|
bloomz: make(map[string]*bloomz.Bloomz),
|
||||||
bert: make(map[string]*bert.Bert),
|
bert: make(map[string]*bert.Bert),
|
||||||
promptsTemplates: make(map[string]*template.Template),
|
promptsTemplates: make(map[string]*template.Template),
|
||||||
|
whisperModels: make(map[string]whisper.Model),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -422,6 +424,33 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
|
|||||||
return model, err
|
return model, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) LoadWhisperModel(modelName string) (whisper.Model, error) {
|
||||||
|
ml.mu.Lock()
|
||||||
|
defer ml.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if we already have a loaded model
|
||||||
|
if !ml.ExistsInModelPath(modelName) {
|
||||||
|
return nil, fmt.Errorf("model does not exist -- %s", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m, ok := ml.whisperModels[modelName]; ok {
|
||||||
|
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the model and keep it in memory for later use
|
||||||
|
modelFile := filepath.Join(ml.ModelPath, modelName)
|
||||||
|
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
|
||||||
|
|
||||||
|
model, err := whisper.New(modelFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ml.whisperModels[modelName] = model
|
||||||
|
return model, err
|
||||||
|
}
|
||||||
|
|
||||||
const tokenizerSuffix = ".tokenizer.json"
|
const tokenizerSuffix = ".tokenizer.json"
|
||||||
|
|
||||||
var loadedModels map[string]interface{} = map[string]interface{}{}
|
var loadedModels map[string]interface{} = map[string]interface{}{}
|
||||||
@ -452,6 +481,16 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ml *ModelLoader) WhisperLoader(backendString string, modelFile string) (model whisper.Model, err error) {
|
||||||
|
//TODO expose more whisper options in next PR
|
||||||
|
switch strings.ToLower(backendString) {
|
||||||
|
case "whisper":
|
||||||
|
return ml.LoadWhisperModel(modelFile)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("whisper backend unsupported: %s", backendString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
|
func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
|
||||||
updateModels := func(model interface{}) {
|
updateModels := func(model interface{}) {
|
||||||
muModels.Lock()
|
muModels.Lock()
|
||||||
|
@ -28,7 +28,7 @@ func audioToWav(src, dst string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Transcript(modelpath, audiopath, language string) (string, error) {
|
func Transcript(model whisper.Model, audiopath, language string) (string, error) {
|
||||||
|
|
||||||
dir, err := os.MkdirTemp("", "whisper")
|
dir, err := os.MkdirTemp("", "whisper")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -58,13 +58,6 @@ func Transcript(modelpath, audiopath, language string) (string, error) {
|
|||||||
|
|
||||||
data := buf.AsFloat32Buffer().Data
|
data := buf.AsFloat32Buffer().Data
|
||||||
|
|
||||||
// Load the model
|
|
||||||
model, err := whisper.New(modelpath)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer model.Close()
|
|
||||||
|
|
||||||
// Process samples
|
// Process samples
|
||||||
context, err := model.NewContext()
|
context, err := model.NewContext()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user