diff --git a/Dockerfile b/Dockerfile
index b57857e..073464f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,4 +17,4 @@ RUN pip install .
 
 EXPOSE 3000
 
-CMD [ "python3", "-m", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "stt:app", "--workers", "1", "-b", "0.0.0.0:3000"]
+CMD [ "python3", "-m", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "stt:app", "--workers", "4", "-b", "0.0.0.0:3000"]
diff --git a/src/runners/audio_transcriber.py b/src/runners/audio_transcriber.py
index 0310a6e..ce72b4a 100644
--- a/src/runners/audio_transcriber.py
+++ b/src/runners/audio_transcriber.py
@@ -14,6 +14,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 class BatchOutput(BaseModel):
     text: str
     batched_count: int
@@ -23,9 +24,11 @@ class BatchOutput(BaseModel):
 
 
 class WhisperHFRunnable(Runnable):
+
     def __init__(self, model_id: str = "openai/whisper-tiny"):
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+        torch_dtype = torch.float16 if torch.cuda.is_available(
+        ) else torch.float32
 
         model = AutoModelForSpeechSeq2Seq.from_pretrained(
             model_id,
@@ -46,9 +49,15 @@ def __init__(self, model_id: str = "openai/whisper-tiny"):
             device=device,
         )
 
-    def forward(self, data: List[bytes]) -> List[BatchOutput]:
+    def forward(self, data: List[dict]) -> List[BatchOutput]:
         start = time.time()
-        result = self.pipe(data, batch_size=len(data))
+
+        audios = [d["raw"] for d in data]
+        langs = [d["lang"] for d in data]
+
+        result = self.pipe(audios,
+                           batch_size=len(data),
+                           generate_kwargs={"language": langs})
         logger.debug(result)
         transcribe_time = time.time() - start
         no_punctuation = str.maketrans('', '', string.punctuation)
@@ -59,6 +68,5 @@ def forward(self, data: List[bytes]) -> List[BatchOutput]:
                 merge_audio_time=0,
                 transcribe_time=transcribe_time,
                 restore_time=0,
-            )
-            for r in result
+            ) for r in result
         ]
diff --git a/src/stream_transcriber.py b/src/stream_transcriber.py
index 9a1b764..3450134 100644
--- a/src/stream_transcriber.py
+++ b/src/stream_transcriber.py
@@ -1,8 +1,9 @@
 import torch
 from faster_whisper import decode_audio
+from faster_whisper.vad import collect_chunks
 
-from silero_vad import (get_speech_timestamps, load_silero_vad, collect_chunks)
+from silero_vad import (get_speech_timestamps, load_silero_vad)
 
 import numpy as np
 import io
 
@@ -103,9 +104,7 @@ async def consume(self, stream_data: bytes):
             )
 
             for chunks in speech_timestamps:
-                speech = collect_chunks(
-                    chunks, torch.tensor(raw_audio_buffer,
-                                         dtype=torch.float32)).numpy()
+                speech = collect_chunks(raw_audio_buffer, chunks)
                 if (not self._speech_audio_buffers
                         or buf2secs(self._speech_audio_buffers[-1]) > 5):
                     self._speech_audio_buffers.append(speech)
diff --git a/src/stt_api.py b/src/stt_api.py
index 34945dd..6229217 100644
--- a/src/stt_api.py
+++ b/src/stt_api.py
@@ -86,6 +86,7 @@ async def handleUpstream(
                     "raw": np.frombuffer(
                         stream.get_speech_audio(), dtype=np.float32
                     ),
+                    "lang": lang[:2],
                     "sampling_rate": 16000,
                 }
             ]
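For context, the reworked forward() now consumes one dict per request instead of raw bytes, and forwards a per-item language list to Whisper generation. A minimal caller sketch follows; it is not part of the change itself: the import path and the silent test clip are assumptions, and passing a list for "language" relies on a transformers release whose Whisper generate() accepts per-sample language codes.

    import numpy as np

    from runners.audio_transcriber import WhisperHFRunnable  # import path assumed

    runner = WhisperHFRunnable()  # defaults to openai/whisper-tiny

    # One second of 16 kHz silence stands in for real speech audio.
    clip = np.zeros(16000, dtype=np.float32)
    batch = [
        {"raw": clip, "lang": "en", "sampling_rate": 16000},
        {"raw": clip, "lang": "de", "sampling_rate": 16000},
    ]

    # forward() pulls out the "raw" arrays and "lang" codes, then makes one
    # batched pipeline call with generate_kwargs={"language": ["en", "de"]}.
    for out in runner.forward(batch):
        print(out.text, out.transcribe_time)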
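The VAD change is similar in spirit: silero_vad still detects the speech timestamps, but the slicing now uses faster_whisper's numpy-based helper, which removes the torch.tensor(...)/.numpy() round-trip of the old call. A standalone sketch of that pairing, assuming the (audio, chunks) argument order used in the hunk above and a silent placeholder clip:

    import numpy as np
    import torch
    from silero_vad import load_silero_vad, get_speech_timestamps
    from faster_whisper.vad import collect_chunks

    vad_model = load_silero_vad()

    # Four seconds of 16 kHz silence as placeholder input.
    audio = np.zeros(16000 * 4, dtype=np.float32)

    # Timestamps come back as dicts of sample offsets: [{"start": s, "end": e}, ...]
    timestamps = get_speech_timestamps(torch.from_numpy(audio), vad_model)

    # collect_chunks concatenates just the speech regions into one array.
    speech = collect_chunks(audio, timestamps)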