diff --git a/env/python/requirements.txt b/env/python/requirements.txt
index 5dab702..4eea27a 100644
--- a/env/python/requirements.txt
+++ b/env/python/requirements.txt
@@ -7,3 +7,5 @@ pydantic
 pydantic-settings
 six
 msgspec
+transformers
+torch
\ No newline at end of file
diff --git a/src/runners/audio_transcriber.py b/src/runners/audio_transcriber.py
index 52cbc06..fa6d7b5 100644
--- a/src/runners/audio_transcriber.py
+++ b/src/runners/audio_transcriber.py
@@ -75,6 +75,7 @@ def add(self, word):
         self.transcription += word.word
 
 
+"""
 class BatchableAudioTranscriber(bentoml.Runnable):
     SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
     SUPPORTS_CPU_MULTI_THREADING = True
@@ -163,3 +164,67 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
             )
 
         return result
+
+"""
+
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+import torch
+
+
+class BatchableAudioTranscriber(bentoml.Runnable):
+    SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
+    SUPPORTS_CPU_MULTI_THREADING = True
+
+    def __init__(self):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.processor = Wav2Vec2Processor.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        )
+        self.model = Wav2Vec2ForCTC.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        ).to(self.device)
+
+    def transcribe(self, audios):
+        # Pad the variable-length waveforms into a single batch tensor.
+        input_values = self.processor(
+            audios, return_tensors="pt", sampling_rate=16000, padding=True
+        ).input_values.to(self.device)
+
+        with torch.no_grad():
+            logits = self.model(input_values).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcriptions = self.processor.batch_decode(predicted_ids)
+
+        return transcriptions
+
+    @bentoml.Runnable.method(batchable=True)
+    def transcribe_audio(self, inputs: list[BatchInput]) -> list[BatchOutput]:
+        # VAD pass: keep only the speech chunks of each clip before batching.
+        ts = datetime.now()
+        audio_batch = []
+        for inp in inputs:
+            wav = decode_audio(io.BytesIO(inp.audio))
+            chunks = get_speech_timestamps(wav)
+            if len(chunks) == 0:
+                # No speech detected: substitute one second of silence.
+                audio_batch.append(np.zeros(16000, dtype=np.float32))
+            else:
+                wav = collect_chunks(wav, chunks=chunks)
+                audio_batch.append(wav)
+
+        merge_time = (datetime.now() - ts).total_seconds()
+
+        ts = datetime.now()
+        segments = self.transcribe(audio_batch)
+        transcribe_time = (datetime.now() - ts).total_seconds()
+
+        return [
+            BatchOutput(
+                text=text,
+                batched_count=len(inputs),
+                merge_audio_time=merge_time,
+                transcribe_time=transcribe_time,
+                restore_time=0,
+            )
+            for text in segments
+        ]
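
Not part of the diff: the replacement Runnable still has to be registered with a service to get adaptive batching. A minimal sketch of how it could be wired up, assuming BentoML 1.x custom-runner conventions (the runner name, batching limits, and import path below are illustrative, not taken from this repo):

import bentoml

from runners.audio_transcriber import BatchableAudioTranscriber

# Wrap the Runnable in a Runner so BentoML can schedule it and batch calls
# to the method marked with batchable=True.
transcriber_runner = bentoml.Runner(
    BatchableAudioTranscriber,
    name="audio_transcriber",
    max_batch_size=32,    # illustrative cap on how many BatchInputs are merged
    max_latency_ms=1000,  # illustrative deadline for flushing a partial batch
)

svc = bentoml.Service("audio-transcription", runners=[transcriber_runner])

API handlers would then call transcriber_runner.transcribe_audio.async_run(batch_input), and BentoML fills the inputs list with however many requests arrive inside the batching window, which is what makes the single padded self.transcribe() call per batch pay off.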