diff --git a/src/runners/audio_transcriber.py b/src/runners/audio_transcriber.py
index 7f54112..aabd21f 100644
--- a/src/runners/audio_transcriber.py
+++ b/src/runners/audio_transcriber.py
@@ -39,11 +39,11 @@ def transcribe_audio(self, audio, lang):
         return {"text": text}
 
 
-from pydantic import BaseModel
 import numpy as np
 import io
 from datetime import datetime
 from faster_whisper.vad import get_speech_timestamps, collect_chunks
+from pydantic import BaseModel
 
 
 class BatchInput(BaseModel):
@@ -60,20 +60,6 @@ class BatchOutput(BaseModel):
     restore_time: float
 
 
-class BatchItem(BaseModel):
-    start_time: float
-    end_time: float
-    chunks_count: int
-    transcription: str = ""
-
-    def add(self, word):
-        if (
-            self.chunks_count > 0
-            and word.start >= self.start_time
-            and word.end <= self.end_time
-        ):
-            self.transcription += word.word
-
 """
 class BatchableAudioTranscriber(bentoml.Runnable):
     SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
@@ -177,11 +163,14 @@ class BatchableAudioTranscriber(bentoml.Runnable):
     SUPPORTS_CPU_MULTI_THREADING = True
 
     def __init__(self):
+        pass
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.processor = Wav2Vec2Processor.from_pretrained(
+            # "facebook/wav2vec2-base-960h"
             "facebook/wav2vec2-large-960h-lv60-self"
         )
         self.model = Wav2Vec2ForCTC.from_pretrained(
+            # "facebook/wav2vec2-base-960h"
            "facebook/wav2vec2-large-960h-lv60-self"
         ).cuda()
 
@@ -194,27 +183,15 @@ def transcribe(self, audios):
         logits = self.model(input_values).logits
         predicted_ids = torch.argmax(logits, dim=-1)
         transcriptions = self.processor.batch_decode(predicted_ids)
-        print(transcriptions)
 
         return transcriptions
 
     @bentoml.Runnable.method(batchable=True)
     def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
-        result = []
-
-        # merging audio
-        ts = datetime.now()
         audio_batch = []
         for input in inputs:
-            wav = decode_audio(io.BytesIO(input.audio))
-            chunks = get_speech_timestamps(wav)
-            if len(chunks) == 0:
-                audio_batch.append(np.zeros(16000, dtype=np.float32))
-            else:
-                wav = collect_chunks(wav, chunks=chunks)
-                audio_batch.append(wav)
-
-        merge_time = (datetime.now() - ts).total_seconds()
+            audio_batch.append(np.frombuffer(input.audio, dtype=np.float32))
 
         ts = datetime.now()
         segments = self.transcribe(audio_batch)
@@ -224,13 +201,14 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
             BatchOutput(
                 text=text,
                 batched_count=len(inputs),
-                merge_audio_time=merge_time,
+                merge_audio_time=0,
                 transcribe_time=transcribe_time,
                 restore_time=0,
             )
             for text in segments
         ]
 
+
 """
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import torch
@@ -298,4 +276,4 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
             )
             for text in segments
         ]
-"""
\ No newline at end of file
+"""
diff --git a/src/stream_transcriber.py b/src/stream_transcriber.py
new file mode 100644
index 0000000..ecacd44
--- /dev/null
+++ b/src/stream_transcriber.py
@@ -0,0 +1,122 @@
+from faster_whisper import decode_audio
+from faster_whisper.vad import get_speech_timestamps, collect_chunks, VadOptions
+
+import numpy as np
+import io
+
+
+def len2secs(x, sample_rate=16000):
+    return x / sample_rate
+
+
+def buf2secs(buf, sample_rate=16000):
+    return len2secs(len(buf), sample_rate)
+
+
+def secs2len(secs, sample_rate=16000):
+    return round(secs * sample_rate)
+
+
+def split_speech_timestamps(speech_timestamps, buffered, split_time):
+    max_offset = secs2len(split_time)
+    buffered = secs2len(buffered)
+
+    timestamps = []
+    while speech_timestamps:
+        timestamps.append([])
+        while speech_timestamps and speech_timestamps[0]["end"] < max_offset - buffered:
+            timestamps[-1].append(speech_timestamps.pop(0))
+
+        max_offset += secs2len(split_time)
+        buffered = 0
+
+    return timestamps
+
+
+class StreamTranscriber:
+    def __init__(self):
+
+        self._raw_stream_data = bytes()
+        self._raw_stream_data_duration = 0
+
+        self._vad_detected_offset = 0
+        self._speech_audio_buffers = []
+        self._speech_timestamps = []
+
+        self._vad_options = VadOptions(
+            min_speech_duration_ms=60, min_silence_duration_ms=60
+        )
+
+    def consume(self, stream_data: bytes):
+        self._raw_stream_data += stream_data
+        try:
+            raw_audio_buffer = decode_audio(io.BytesIO(self._raw_stream_data))
+            raw_audio_buffer = raw_audio_buffer[self._vad_detected_offset :]
+        except Exception:  # decoding fails while the stream container is still incomplete
+            return
+
+        self._raw_stream_data_duration = buf2secs(raw_audio_buffer)
+
+        speech_timestamps = get_speech_timestamps(
+            raw_audio_buffer, vad_options=self._vad_options
+        )
+
+        if not speech_timestamps:
+            return
+
+        # remove the speech chunks which probably are not ended
+        while (
+            speech_timestamps
+            and speech_timestamps[-1]["end"]
+            > len(raw_audio_buffer) - self._vad_options.min_silence_duration_ms * 16
+        ):
+            del speech_timestamps[-1]
+
+        if not speech_timestamps:
+            return
+
+        self._vad_detected_offset += speech_timestamps[-1]["end"]
+
+        buffered = 0
+        if self._speech_audio_buffers:
+            buffered = buf2secs(self._speech_audio_buffers[-1])
+
+        print(speech_timestamps)
+
+        speech_timestamps = split_speech_timestamps(
+            speech_timestamps,
+            buffered,
+            5,
+        )
+
+        print(speech_timestamps)
+
+        for chunks in speech_timestamps:
+            speech = collect_chunks(raw_audio_buffer, chunks)
+            if (
+                not self._speech_audio_buffers
+                or buf2secs(self._speech_audio_buffers[-1]) > 5
+            ):
+                self._speech_audio_buffers.append(speech)
+            else:
+                self._speech_audio_buffers[-1] = np.append(
+                    self._speech_audio_buffers[-1], speech
+                )
+
+        [print(buf2secs(x)) for x in self._speech_audio_buffers]
+
+        print(len2secs(self._raw_stream_data_duration), self._vad_detected_offset)
+
+    def should_transcribe(self):
+        if not self._speech_audio_buffers:
+            return False
+        if buf2secs(self._speech_audio_buffers[0]) > 1:
+            return True
+        if self._raw_stream_data_duration > 3:
+            return True
+        return False
+
+    def get_speech_audio(self) -> bytes:
+        assert self.should_transcribe()
+
+        return self._speech_audio_buffers.pop(0).tobytes()
diff --git a/src/stt_api.py b/src/stt_api.py
index f1f036b..fe1a2ac 100644
--- a/src/stt_api.py
+++ b/src/stt_api.py
@@ -4,7 +4,6 @@
 import bentoml
 from runners.audio_transcriber import (
-    AudioTranscriber,
     BatchableAudioTranscriber,
     BatchInput,
 )
 
@@ -17,6 +16,7 @@
 from utils.service_key.brave_service_key import check_stt_request
 
 import utils.ipc as ipc
+from stream_transcriber import StreamTranscriber
 
 runner_audio_transcriber = bentoml.Runner(
     BatchableAudioTranscriber,
@@ -60,48 +60,60 @@ async def handleUpstream(
     try:
         mic_data = bytes()
         text = ""
+        stream = StreamTranscriber()
         async with ipc.client.Publisher(pair) as pipe:
             try:
                 async for chunk in request.stream():
                     if len(chunk) == 0:
                         break
-                    mic_data += chunk
-                    process_time = datetime.now()
-                    transciption = await runner_audio_transcriber.async_run(
-                        [BatchInput(audio=mic_data, lang=lang, pair=pair)]
-                    )
-                    process_time = (datetime.now() - process_time).total_seconds()
-
-                    out = transciption[0]
-                    print(
-                        pair,
-                        " : ",
-                        out.batched_count,
-                        "",
-                        out.merge_audio_time,
-                        " ",
-                        out.transcribe_time,
-                        " ",
-                        out.restore_time,
-                    )
-
-                    if out.text:
-                        await pipe.push(
-                            ipc.messages.Text(
-                                out.text,
-                                False,
-                                len(mic_data),
-                                out.merge_audio_time
-                                + out.transcribe_time
-                                + out.restore_time,
-                                process_time,
-                            )
+                    stream.consume(chunk)
+
+                    while stream.should_transcribe():
+                        process_time = datetime.now()
+                        transciption = await runner_audio_transcriber.async_run(
+                            [
+                                BatchInput(
+                                    audio=stream.get_speech_audio(),
+                                    lang=lang,
+                                    pair=pair,
+                                )
+                            ]
+                        )
+                        process_time = (datetime.now() - process_time).total_seconds()
+
+                        out = transciption[0]
+                        print(
+                            pair,
+                            " : ",
+                            out.batched_count,
+                            "",
+                            out.merge_audio_time,
+                            " ",
+                            out.transcribe_time,
+                            " ",
+                            out.restore_time,
                         )
 
+
+                        if out.text:
+                            text += out.text.lower() + " "
+                            await pipe.push(
+                                ipc.messages.Text(
+                                    text,
+                                    False,
+                                    len(mic_data),
+                                    out.merge_audio_time
+                                    + out.transcribe_time
+                                    + out.restore_time,
+                                    process_time,
+                                )
+                            )
+
             finally:
                 if text:
                     await pipe.push(ipc.messages.Text(text, True))
     except Exception as e:
+        raise
         return JSONResponse(
             content=jsonable_encoder({"status": "exception", "exception": str(e)})
         )
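
Usage note: a minimal offline sketch of the streaming flow this patch introduces, not part of the diff itself. Encoded audio is fed to StreamTranscriber in small chunks; whenever enough finished speech has been buffered, the raw float32 PCM is popped and transcribed. The file name, the 4 KiB chunk size, and the direct faster_whisper.WhisperModel call are assumptions standing in for the HTTP request stream and the BentoML runner used in stt_api.py.

from faster_whisper import WhisperModel
import numpy as np

from stream_transcriber import StreamTranscriber

model = WhisperModel("tiny.en")  # assumption: any locally available Whisper model
stream = StreamTranscriber()

with open("sample.webm", "rb") as f:  # assumption: any container ffmpeg can decode
    while chunk := f.read(4096):  # stand-in for `async for chunk in request.stream()`
        stream.consume(chunk)

        while stream.should_transcribe():
            # get_speech_audio() returns 16 kHz float32 PCM as bytes, matching the
            # np.frombuffer(input.audio, dtype=np.float32) call on the runner side.
            pcm = np.frombuffer(stream.get_speech_audio(), dtype=np.float32)
            segments, _ = model.transcribe(pcm, language="en")
            print(" ".join(segment.text.strip() for segment in segments))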