feat: [whisper] Partial support for verbose_json format in transcribe endpoint (#721)

This commit is contained in:
Luis López 2023-07-04 14:31:31 +02:00 committed by GitHub
parent f3063f98d3
commit a6839fd238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 12 deletions

View File

@ -737,7 +737,7 @@ func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
log.Debug().Msgf("Trascribed: %+v", tr) log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here // TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) return c.Status(http.StatusOK).JSON(tr)
} }
} }

View File

@ -5,11 +5,25 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"time"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav" wav "github.com/go-audio/wav"
) )
// Segment is one timed piece of a transcription. Transcript builds one
// Segment for every segment it reads back from the whisper context.
type Segment struct {
// Id is copied from the whisper segment's Num field.
Id int `json:"id"`
// Start and End are copied from the whisper segment's Start and End fields.
// NOTE(review): time.Duration has no custom JSON encoding, so these are
// emitted as integer nanosecond counts — confirm API clients expect that unit.
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
// Text is the text of this segment.
Text string `json:"text"`
// Tokens holds the Id of every token in the whisper segment. It is left nil
// (encoded as JSON null) when the segment carries no tokens.
Tokens []int `json:"tokens"`
}
// Result is the value returned by Transcript; the transcript endpoint
// serializes it directly as the JSON response body.
type Result struct {
// Segments lists the transcription segments in the order they were read
// from the whisper context. It is nil (JSON null) if none were produced.
Segments []Segment `json:"segments"`
// Text is the concatenation of every segment's Text, with no separator added.
Text string `json:"text"`
}
func sh(c string) (string, error) { func sh(c string) (string, error) {
cmd := exec.Command("/bin/sh", "-c", c) cmd := exec.Command("/bin/sh", "-c", c)
cmd.Env = os.Environ() cmd.Env = os.Environ()
@ -28,24 +42,25 @@ func audioToWav(src, dst string) error {
return nil return nil
} }
func Transcript(model whisper.Model, audiopath, language string, threads uint) (string, error) { func Transcript(model whisper.Model, audiopath, language string, threads uint) (Result, error) {
res := Result{}
dir, err := os.MkdirTemp("", "whisper") dir, err := os.MkdirTemp("", "whisper")
if err != nil { if err != nil {
return "", err return res, err
} }
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav") convertedPath := filepath.Join(dir, "converted.wav")
if err := audioToWav(audiopath, convertedPath); err != nil { if err := audioToWav(audiopath, convertedPath); err != nil {
return "", err return res, err
} }
// Open samples // Open samples
fh, err := os.Open(convertedPath) fh, err := os.Open(convertedPath)
if err != nil { if err != nil {
return "", err return res, err
} }
defer fh.Close() defer fh.Close()
@ -53,7 +68,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
d := wav.NewDecoder(fh) d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer() buf, err := d.FullPCMBuffer()
if err != nil { if err != nil {
return "", err return res, err
} }
data := buf.AsFloat32Buffer().Data data := buf.AsFloat32Buffer().Data
@ -61,7 +76,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
// Process samples // Process samples
context, err := model.NewContext() context, err := model.NewContext()
if err != nil { if err != nil {
return "", err return res, err
} }
@ -74,17 +89,25 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
} }
if err := context.Process(data, nil, nil); err != nil { if err := context.Process(data, nil, nil); err != nil {
return "", err return res, err
} }
text := ""
for { for {
segment, err := context.NextSegment() s, err := context.NextSegment()
if err != nil { if err != nil {
break break
} }
text += segment.Text
var tokens []int
for _, t := range(s.Tokens) {
tokens = append(tokens, t.Id)
} }
return text, nil segment := Segment{Id: s.Num, Text: s.Text, Start:s.Start, End: s.End, Tokens: tokens}
res.Segments = append(res.Segments, segment)
res.Text += s.Text
}
return res, nil
} }