diff --git a/api/openai.go b/api/openai.go index 71cb0307..403a03ba 100644 --- a/api/openai.go +++ b/api/openai.go @@ -737,7 +737,7 @@ func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { log.Debug().Msgf("Trascribed: %+v", tr) // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) + return c.Status(http.StatusOK).JSON(tr) } } diff --git a/pkg/whisper/whisper.go b/pkg/whisper/whisper.go index 882c96ab..63e8cc5b 100644 --- a/pkg/whisper/whisper.go +++ b/pkg/whisper/whisper.go @@ -5,11 +5,25 @@ import ( "os" "os/exec" "path/filepath" + "time" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" wav "github.com/go-audio/wav" ) +type Segment struct { + Id int `json:"id"` + Start time.Duration `json:"start"` + End time.Duration `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` +} + +type Result struct { + Segments []Segment `json:"segments"` + Text string `json:"text"` +} + func sh(c string) (string, error) { cmd := exec.Command("/bin/sh", "-c", c) cmd.Env = os.Environ() @@ -28,24 +42,25 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (string, error) { +func Transcript(model whisper.Model, audiopath, language string, threads uint) (Result, error) { + res := Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { - return "", err + return res, err } defer os.RemoveAll(dir) convertedPath := filepath.Join(dir, "converted.wav") if err := audioToWav(audiopath, convertedPath); err != nil { - return "", err + return res, err } // Open samples fh, err := os.Open(convertedPath) if err != nil { - return "", err + return res, err } defer fh.Close() @@ -53,7 +68,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( d := wav.NewDecoder(fh) buf, err := d.FullPCMBuffer() if err != nil { - return "", err + return res, err } data := buf.AsFloat32Buffer().Data @@ -61,7 +76,7 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( // Process samples context, err := model.NewContext() if err != nil { - return "", err + return res, err } @@ -74,17 +89,25 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( } if err := context.Process(data, nil, nil); err != nil { - return "", err + return res, err } - text := "" for { - segment, err := context.NextSegment() + s, err := context.NextSegment() if err != nil { break } - text += segment.Text + + var tokens []int + for _, t := range(s.Tokens) { + tokens = append(tokens, t.Id) + } + + segment := Segment{Id: s.Num, Text: s.Text, Start:s.Start, End: s.End, Tokens: tokens} + res.Segments = append(res.Segments, segment) + + res.Text += s.Text } - return text, nil + return res, nil }