Skip to content

Commit

Permalink
Vad detection.
Browse files Browse the repository at this point in the history
  • Loading branch information
boocmp committed Aug 21, 2024
1 parent e7e9985 commit 2741d3a
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 62 deletions.
38 changes: 8 additions & 30 deletions src/runners/audio_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def transcribe_audio(self, audio, lang):
return {"text": text}


from pydantic import BaseModel
import numpy as np
import io
from datetime import datetime
from faster_whisper.vad import get_speech_timestamps, collect_chunks
from pydantic import BaseModel


class BatchInput(BaseModel):
Expand All @@ -60,20 +60,6 @@ class BatchOutput(BaseModel):
restore_time: float


class BatchItem(BaseModel):
start_time: float
end_time: float
chunks_count: int
transcription: str = ""

def add(self, word):
if (
self.chunks_count > 0
and word.start >= self.start_time
and word.end <= self.end_time
):
self.transcription += word.word

"""
class BatchableAudioTranscriber(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
Expand Down Expand Up @@ -177,11 +163,14 @@ class BatchableAudioTranscriber(bentoml.Runnable):
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self):
pass
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.processor = Wav2Vec2Processor.from_pretrained(
# "facebook/wav2vec2-base-960h"
"facebook/wav2vec2-large-960h-lv60-self"
)
self.model = Wav2Vec2ForCTC.from_pretrained(
# "facebook/wav2vec2-base-960h"
"facebook/wav2vec2-large-960h-lv60-self"
).cuda()

Expand All @@ -194,27 +183,15 @@ def transcribe(self, audios):
logits = self.model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcriptions = self.processor.batch_decode(predicted_ids)
print(transcriptions)

return transcriptions

@bentoml.Runnable.method(batchable=True)
def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
result = []

# merging audio
ts = datetime.now()
audio_batch = []
for input in inputs:
wav = decode_audio(io.BytesIO(input.audio))
chunks = get_speech_timestamps(wav)
if len(chunks) == 0:
audio_batch.append(np.zeros(16000, dtype=np.float32))
else:
wav = collect_chunks(wav, chunks=chunks)
audio_batch.append(wav)

merge_time = (datetime.now() - ts).total_seconds()
audio_batch.append(np.frombuffer(input.audio, dtype=np.float32))

ts = datetime.now()
segments = self.transcribe(audio_batch)
Expand All @@ -224,13 +201,14 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
BatchOutput(
text=text,
batched_count=len(inputs),
merge_audio_time=merge_time,
merge_audio_time=0,
transcribe_time=transcribe_time,
restore_time=0,
)
for text in segments
]


"""
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
Expand Down Expand Up @@ -298,4 +276,4 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
)
for text in segments
]
"""
"""
122 changes: 122 additions & 0 deletions src/stream_transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from faster_whisper import decode_audio
from faster_whisper.vad import get_speech_timestamps, collect_chunks, VadOptions

import numpy as np
import io


def len2secs(x, sample_rate=16000):
return x / sample_rate


def buf2secs(buf, sample_rate=16000):
return len2secs(len(buf), sample_rate)


def secs2len(secs, sample_rate=16000):
return round(secs * 16000)


def split_speech_timestamps(speech_timestamps, buffered, split_time):
max_offset = secs2len(split_time)
buffered = secs2len(buffered)

timestamps = []
while speech_timestamps:
timestamps.append([])
while speech_timestamps and speech_timestamps[0]["end"] < max_offset - buffered:
timestamps[-1].append(speech_timestamps.pop(0))

max_offset += secs2len(split_time)
buffered = 0

return timestamps


class StreamTranscriber:
def __init__(self):

self._raw_stream_data = bytes()
self._raw_stream_data_duration = 0

self._vad_detected_offset = 0
self._speech_audio_buffers = []
self._speech_timestamps = []

self._vad_options = VadOptions(
min_speech_duration_ms=60, min_silence_duration_ms=60
)

def consume(self, stream_data: bytes):
self._raw_stream_data += stream_data
try:
raw_audio_buffer = decode_audio(io.BytesIO(self._raw_stream_data))
raw_audio_buffer = raw_audio_buffer[self._vad_detected_offset :]
except:
return

self._raw_stream_data_duration = buf2secs(raw_audio_buffer)

speech_timestamps = get_speech_timestamps(
raw_audio_buffer, vad_options=self._vad_options
)

if not speech_timestamps:
return

# remove the speech chunks which probably are not ended
while (
speech_timestamps
and speech_timestamps[-1]["end"]
> len(raw_audio_buffer) - self._vad_options.min_silence_duration_ms * 16
):
del speech_timestamps[-1]

if not speech_timestamps:
return

self._vad_detected_offset += speech_timestamps[-1]["end"]

buffered = 0
if self._speech_audio_buffers:
buffered = buf2secs(self._speech_audio_buffers[-1])

print(speech_timestamps)

speech_timestamps = split_speech_timestamps(
speech_timestamps,
buffered,
5,
)

print(speech_timestamps)

for chunks in speech_timestamps:
speech = collect_chunks(raw_audio_buffer, chunks)
if (
not self._speech_audio_buffers
or buf2secs(self._speech_audio_buffers[-1]) > 5
):
self._speech_audio_buffers.append(speech)
else:
self._speech_audio_buffers[-1] = np.append(
self._speech_audio_buffers[-1], speech
)

[print(buf2secs(x)) for x in self._speech_audio_buffers]

print(len2secs(self._raw_stream_data_duration), self._vad_detected_offset)

def should_transcribe(self):
if not self._speech_audio_buffers:
return False
if buf2secs(self._speech_audio_buffers[0]) > 1:
return True
if self._raw_stream_data_duration > 3:
return True
return False

def get_speech_audio(self) -> bytes:
assert self.should_transcribe()

return self._speech_audio_buffers.pop(0).tobytes()
76 changes: 44 additions & 32 deletions src/stt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import bentoml
from runners.audio_transcriber import (
AudioTranscriber,
BatchableAudioTranscriber,
BatchInput,
)
Expand All @@ -17,6 +16,7 @@
from utils.service_key.brave_service_key import check_stt_request

import utils.ipc as ipc
from stream_transcriber import StreamTranscriber

runner_audio_transcriber = bentoml.Runner(
BatchableAudioTranscriber,
Expand Down Expand Up @@ -60,48 +60,60 @@ async def handleUpstream(
try:
mic_data = bytes()
text = ""
stream = StreamTranscriber()
async with ipc.client.Publisher(pair) as pipe:
try:
async for chunk in request.stream():
if len(chunk) == 0:
break
mic_data += chunk
process_time = datetime.now()
transciption = await runner_audio_transcriber.async_run(
[BatchInput(audio=mic_data, lang=lang, pair=pair)]
)
process_time = (datetime.now() - process_time).total_seconds()

out = transciption[0]
print(
pair,
" : ",
out.batched_count,
"",
out.merge_audio_time,
" ",
out.transcribe_time,
" ",
out.restore_time,
)

if out.text:
await pipe.push(
ipc.messages.Text(
out.text,
False,
len(mic_data),
out.merge_audio_time
+ out.transcribe_time
+ out.restore_time,
process_time,
)
stream.consume(chunk)

while stream.should_transcribe():
process_time = datetime.now()
transciption = await runner_audio_transcriber.async_run(
[
BatchInput(
audio=stream.get_speech_audio(),
lang=lang,
pair=pair,
)
]
)
process_time = (datetime.now() - process_time).total_seconds()

out = transciption[0]
print(
pair,
" : ",
out.batched_count,
"",
out.merge_audio_time,
" ",
out.transcribe_time,
" ",
out.restore_time,
)

if out.text:
text += out.text.lower() + " "
await pipe.push(
ipc.messages.Text(
text,
False,
len(mic_data),
out.merge_audio_time
+ out.transcribe_time
+ out.restore_time,
process_time,
)
)

finally:
if text:
await pipe.push(ipc.messages.Text(text, True))

except Exception as e:
raise
return JSONResponse(
content=jsonable_encoder({"status": "exception", "exception": str(e)})
)
Expand Down

0 comments on commit 2741d3a

Please sign in to comment.