feat: [whisper] Partial support for verbose_json format in transcribe endpoint (#721)

This commit is contained in:
Luis López 2023-07-04 14:31:31 +02:00 committed by GitHub
parent f3063f98d3
commit a6839fd238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 12 deletions

View File

@ -737,7 +737,7 @@ func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr})
return c.Status(http.StatusOK).JSON(tr)
}
}

View File

@ -5,11 +5,25 @@ import (
"os"
"os/exec"
"path/filepath"
"time"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav"
)
type Segment struct {
Id int `json:"id"`
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
}
type Result struct {
Segments []Segment `json:"segments"`
Text string `json:"text"`
}
func sh(c string) (string, error) {
cmd := exec.Command("/bin/sh", "-c", c)
cmd.Env = os.Environ()
@ -28,24 +42,25 @@ func audioToWav(src, dst string) error {
return nil
}
func Transcript(model whisper.Model, audiopath, language string, threads uint) (string, error) {
func Transcript(model whisper.Model, audiopath, language string, threads uint) (Result, error) {
res := Result{}
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return "", err
return res, err
}
defer os.RemoveAll(dir)
convertedPath := filepath.Join(dir, "converted.wav")
if err := audioToWav(audiopath, convertedPath); err != nil {
return "", err
return res, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return "", err
return res, err
}
defer fh.Close()
@ -53,7 +68,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return "", err
return res, err
}
data := buf.AsFloat32Buffer().Data
@ -61,7 +76,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
// Process samples
context, err := model.NewContext()
if err != nil {
return "", err
return res, err
}
@ -74,17 +89,25 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
}
if err := context.Process(data, nil, nil); err != nil {
return "", err
return res, err
}
text := ""
for {
segment, err := context.NextSegment()
s, err := context.NextSegment()
if err != nil {
break
}
text += segment.Text
var tokens []int
for _, t := range(s.Tokens) {
tokens = append(tokens, t.Id)
}
segment := Segment{Id: s.Num, Text: s.Text, Start:s.Start, End: s.End, Tokens: tokens}
res.Segments = append(res.Segments, segment)
res.Text += s.Text
}
return text, nil
return res, nil
}