
Commit

language param.
boocmp committed Sep 30, 2024
1 parent 751908c commit 66fc806
Showing 4 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -17,4 +17,4 @@ RUN pip install .

EXPOSE 3000

CMD [ "python3", "-m", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "stt:app", "--workers", "1", "-b", "0.0.0.0:3000"]
CMD [ "python3", "-m", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "stt:app", "--workers", "4", "-b", "0.0.0.0:3000"]
18 changes: 13 additions & 5 deletions src/runners/audio_transcriber.py
@@ -14,6 +14,7 @@

logger = logging.getLogger(__name__)


class BatchOutput(BaseModel):
text: str
batched_count: int
@@ -23,9 +24,11 @@ class BatchOutput(BaseModel):


class WhisperHFRunnable(Runnable):

def __init__(self, model_id: str = "openai/whisper-tiny"):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+torch_dtype = torch.float16 if torch.cuda.is_available(
+) else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
@@ -46,9 +49,15 @@ def __init__(self, model_id: str = "openai/whisper-tiny"):
device=device,
)

-def forward(self, data: List[bytes]) -> List[BatchOutput]:
+def forward(self, data: List[any]) -> List[BatchOutput]:
start = time.time()
-result = self.pipe(data, batch_size=len(data))
+
+audios = [d["raw"] for d in data]
+langs = [d["lang"] for d in data]
+
+result = self.pipe(audios,
+                   batch_size=len(data),
+                   generate_kwargs={"language": langs})
logger.debug(result)
transcribe_time = time.time() - start
no_punctuation = str.maketrans('', '', string.punctuation)
@@ -59,6 +68,5 @@ def forward(self, data: List[bytes]) -> List[BatchOutput]:
merge_audio_time=0,
transcribe_time=transcribe_time,
restore_time=0,
-)
-for r in result
+) for r in result
]
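
Note (not part of the diff): with this change, forward() no longer takes raw byte buffers; each batch element is a dict carrying the audio samples and a language hint. A minimal calling sketch with illustrative values, assuming the class above is importable from this repo's src layout:

import numpy as np

from runners.audio_transcriber import WhisperHFRunnable  # import path assumed

runner = WhisperHFRunnable()  # loads openai/whisper-tiny on first use

batch = [
    {
        "raw": np.zeros(16000, dtype=np.float32),  # stand-in for 1 s of 16 kHz speech
        "lang": "en",                              # two-letter code forwarded via generate_kwargs
        "sampling_rate": 16000,
    }
]

outputs = runner.forward(batch)  # expected: one BatchOutput per input item
print(outputs[0].text, outputs[0].transcribe_time)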
7 changes: 3 additions & 4 deletions src/stream_transcriber.py
@@ -1,8 +1,9 @@
import torch

from faster_whisper import decode_audio
+from faster_whisper.vad import collect_chunks

-from silero_vad import (get_speech_timestamps, load_silero_vad, collect_chunks)
+from silero_vad import (get_speech_timestamps, load_silero_vad)

import numpy as np
import io
@@ -103,9 +104,7 @@ async def consume(self, stream_data: bytes):
)

for chunks in speech_timestamps:
-speech = collect_chunks(
-    chunks, torch.tensor(raw_audio_buffer,
-                         dtype=torch.float32)).numpy()
+speech = collect_chunks(raw_audio_buffer, chunks)
if (not self._speech_audio_buffers
or buf2secs(self._speech_audio_buffers[-1]) > 5):
self._speech_audio_buffers.append(speech)
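
Note (not part of the diff): the call above drops the torch tensor round-trip and uses faster_whisper.vad.collect_chunks, which takes the numpy audio first and the VAD timestamps second. A rough standalone sketch of that path, assuming collect_chunks concatenates the detected start/end sample ranges (faster-whisper 1.0.x behaviour) and that get_speech_timestamps accepts a float32 numpy buffer:

import numpy as np
from faster_whisper.vad import collect_chunks
from silero_vad import get_speech_timestamps, load_silero_vad

vad_model = load_silero_vad()

raw_audio_buffer = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of 16 kHz mono audio
speech_timestamps = get_speech_timestamps(raw_audio_buffer, vad_model)  # [{"start": ..., "end": ...}, ...]

# Speech-only samples; no .numpy() conversion needed afterwards.
speech = collect_chunks(raw_audio_buffer, speech_timestamps)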
1 change: 1 addition & 0 deletions src/stt_api.py
@@ -86,6 +86,7 @@ async def handleUpstream(
"raw": np.frombuffer(
stream.get_speech_audio(), dtype=np.float32
),
"lang": lang[:2],
"sampling_rate": 16000,
}
]
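
Note (not part of the diff): lang[:2] reduces whatever language tag the client sent to a two-letter code before the item reaches WhisperHFRunnable.forward(). Assuming a BCP-47 style tag arrives (hypothetical value below), the payload item becomes:

import numpy as np

lang = "en-US"  # hypothetical incoming value; the real format depends on the client

item = {
    "raw": np.zeros(16000, dtype=np.float32),  # stand-in for stream.get_speech_audio()
    "lang": lang[:2],                          # "en-US" -> "en"
    "sampling_rate": 16000,
}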
