Skip to content

Commit

Permalink
wav2vec2 backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
boocmp committed Aug 19, 2024
1 parent 40befd2 commit 343cae5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
2 changes: 2 additions & 0 deletions env/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ pydantic
pydantic-settings
six
msgspec
transformers
torch
67 changes: 67 additions & 0 deletions src/runners/audio_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def add(self, word):
self.transcription += word.word


"""
class BatchableAudioTranscriber(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
Expand Down Expand Up @@ -163,3 +164,69 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
)
return result
"""

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from itertools import groupby


class BatchableAudioTranscriber(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-base-960h"
)
self.model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-base-960h"
).cuda()

def transcribe(self, audios):
input_values = self.processor(
audios, return_tensors="pt", sampling_rate=16000
).input_values.cuda()

with torch.no_grad():
logits = self.model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcriptions = self.processor.batch_decode(predicted_ids)
print(transcriptions)

return transcriptions

@bentoml.Runnable.method(batchable=True)
def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
result = []

# merging audio
ts = datetime.now()
audio_batch = []
for input in inputs:
wav = decode_audio(io.BytesIO(input.audio))
chunks = get_speech_timestamps(wav)
if len(chunks) == 0:
audio_batch.append(np.zeros(16000, dtype=np.float32))
else:
wav = collect_chunks(wav, chunks=chunks)
audio_batch.append(wav)

merge_time = (datetime.now() - ts).total_seconds()

ts = datetime.now()
segments = self.transcribe(audio_batch)
transcribe_time = (datetime.now() - ts).total_seconds()

return [
BatchOutput(
text=text,
batched_count=len(inputs),
merge_audio_time=merge_time,
transcribe_time=transcribe_time,
restore_time=0,
)
for text in segments
]

0 comments on commit 343cae5

Please sign in to comment.