[Voxtral] Fix speech transcription api (#31388)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: bk-201 <joy25810@foxmail.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: bk-201 <joy25810@foxmail.com> Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com> Co-authored-by: Anexdeus <5142168@mail.ru> Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2972a05473
commit
18d4e481d0
@@ -17,10 +17,11 @@ class SpeechToTextConfig:
|
||||
16kHz audio input. The input audio will be automatically resampled to this
|
||||
rate before processing."""
|
||||
|
||||
max_audio_clip_s: int = 30
|
||||
max_audio_clip_s: int | None = 30
|
||||
"""Maximum duration in seconds for a single audio clip without chunking.
|
||||
Audio longer than this will be split into smaller chunks if
|
||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected.
|
||||
`None` means audio duration can be unlimited and won't be chunked."""
|
||||
|
||||
overlap_chunk_second: int = 1
|
||||
"""Overlap duration in seconds between consecutive audio chunks when
|
||||
|
||||
@@ -477,7 +477,15 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
}
|
||||
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
|
||||
text = ""
|
||||
chunk_size_in_s = self.asr_config.max_audio_clip_s
|
||||
if chunk_size_in_s is None:
|
||||
assert len(list_result_generator) == 1, (
|
||||
"`max_audio_clip_s` is set to None, audio cannot be chunked"
|
||||
)
|
||||
for idx, result_generator in enumerate(list_result_generator):
|
||||
start_time = (
|
||||
float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
|
||||
)
|
||||
async for op in result_generator:
|
||||
if request.response_format == "verbose_json":
|
||||
segments: list[SpeechToTextSegment] = (
|
||||
@@ -485,7 +493,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
tokens=tuple(op.outputs[0].token_ids),
|
||||
segment_class=segment_class,
|
||||
request=request,
|
||||
start_time=idx * self.asr_config.max_audio_clip_s,
|
||||
start_time=start_time,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -653,6 +661,10 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
def _split_audio(
|
||||
self, audio_data: np.ndarray, sample_rate: int
|
||||
) -> list[np.ndarray]:
|
||||
assert self.asr_config.max_audio_clip_s is not None, (
|
||||
f"{self.asr_config.max_audio_clip_s=} cannot be None to"
|
||||
" split audio into chunks."
|
||||
)
|
||||
chunk_size = sample_rate * self.asr_config.max_audio_clip_s
|
||||
overlap_size = sample_rate * self.asr_config.overlap_chunk_second
|
||||
chunks = []
|
||||
|
||||
@@ -17,7 +17,11 @@ from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChu
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.transcription.request import TranscriptionRequest
|
||||
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
|
||||
from mistral_common.tokens.tokenizers.audio import (
|
||||
Audio,
|
||||
AudioEncoder,
|
||||
TranscriptionFormat,
|
||||
)
|
||||
from transformers import BatchFeature, TensorType, WhisperConfig
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
@@ -157,13 +161,17 @@ class VoxtralProcessorAdapter:
|
||||
|
||||
# pad if necessary
|
||||
# TODO(Patrick) - remove once mistral-common is bumped
|
||||
sig = inspect.signature(self._audio_processor.pad)
|
||||
if "is_online_streaming" in sig.parameters:
|
||||
audio = self._audio_processor.pad(
|
||||
audio, self.sampling_rate, is_online_streaming=False
|
||||
)
|
||||
else:
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
if (
|
||||
self._audio_processor.audio_config.transcription_format
|
||||
!= TranscriptionFormat.STREAMING
|
||||
):
|
||||
sig = inspect.signature(self._audio_processor.pad)
|
||||
if "is_online_streaming" in sig.parameters:
|
||||
audio = self._audio_processor.pad(
|
||||
audio, self.sampling_rate, is_online_streaming=False
|
||||
)
|
||||
else:
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
|
||||
audio_tokens = [self.begin_audio_token_id] + [
|
||||
self.audio_token_id
|
||||
|
||||
@@ -3,10 +3,19 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mistral_common.protocol.instruct.chunk import RawAudio
|
||||
from mistral_common.protocol.transcription.request import (
|
||||
StreamingMode,
|
||||
TranscriptionRequest,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.audio import Audio
|
||||
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.voxtral import (
|
||||
@@ -27,6 +36,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
|
||||
from .utils import (
|
||||
_flatten_embeddings,
|
||||
@@ -205,13 +215,17 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||
"For streaming you must provide an audio input at every step."
|
||||
)
|
||||
|
||||
multiple_of = self.audio_config.raw_audio_length_per_tok
|
||||
assert all(
|
||||
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
|
||||
), (
|
||||
f"Every input audio waveform has to be a multiple of {multiple_of}, but"
|
||||
f" one is {this_audio} with {(this_audio / multiple_of)=}."
|
||||
)
|
||||
def _truncate_left(
|
||||
sample: torch.Tensor, mult_of: int, pos: int
|
||||
) -> torch.Tensor:
|
||||
assert pos in [0, 1], pos
|
||||
if (ctx := sample.shape[pos] % mult_of) != 0:
|
||||
sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
|
||||
assert sample.shape[pos] > 0, (
|
||||
f"Sample is empty after truncation with ctx {ctx}"
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
mel_features = [
|
||||
self.whisper_encoder.compute_whisper_melspec(audio).to(
|
||||
@@ -219,11 +233,16 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||
)
|
||||
for audio in audio_inputs
|
||||
]
|
||||
|
||||
# we truncate the left most mel feature
|
||||
# if the sequence length in impair
|
||||
mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features]
|
||||
|
||||
seq_lens = [mel.shape[1] for mel in mel_features]
|
||||
# [total_num_20ms_frames, hidden_size]
|
||||
audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
|
||||
mel_features
|
||||
)[0]
|
||||
)
|
||||
conv_stride = self.whisper_encoder.whisper_encoder.total_stride
|
||||
audio_embeddings_per_sample = audio_embeddings.split(
|
||||
[s // conv_stride for s in seq_lens], dim=0
|
||||
@@ -231,13 +250,55 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||
|
||||
# audio_embeddings per sample need to be divisible by 4
|
||||
pool_size = self.config.audio_config.block_pool_size
|
||||
assert all(
|
||||
(this_shape := sample.shape[0]) % pool_size == 0
|
||||
|
||||
audio_embeddings_per_sample = [
|
||||
_truncate_left(sample, pool_size, 0)
|
||||
for sample in audio_embeddings_per_sample
|
||||
), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
|
||||
]
|
||||
|
||||
audio_embeddings_per_sample = [
|
||||
e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
|
||||
for e in audio_embeddings_per_sample
|
||||
]
|
||||
return audio_embeddings_per_sample
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
) -> SpeechToTextConfig:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio_config = tokenizer.instruct.audio_encoder.audio_config
|
||||
sample_rate = audio_config.sampling_rate
|
||||
return SpeechToTextConfig(
|
||||
max_audio_clip_s=None, # only limited by memory
|
||||
sample_rate=sample_rate,
|
||||
min_energy_split_window_size=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# for speech-to-text transcription
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
model_config: ModelConfig,
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: str | None,
|
||||
) -> PromptType:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
|
||||
|
||||
req = TranscriptionRequest(
|
||||
model=model_config.model,
|
||||
audio=RawAudio.from_audio(audio),
|
||||
language=language,
|
||||
streaming=StreamingMode.OFFLINE,
|
||||
)
|
||||
|
||||
tokenized = tokenizer.instruct.encode_transcription(req)
|
||||
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
|
||||
prompts_dict = {"multi_modal_data": {"audio": audio}}
|
||||
prompts_dict["prompt_token_ids"] = tokenized.tokens
|
||||
return cast(PromptType, prompts_dict)
|
||||
|
||||
@@ -469,8 +469,10 @@ class WhisperEncoder(nn.Module):
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
is_causal = getattr(config, "is_causal", False)
|
||||
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
|
||||
self.is_causal = getattr(config, "is_causal", False)
|
||||
Conv1d = (
|
||||
WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
|
||||
)
|
||||
|
||||
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
|
||||
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
|
||||
@@ -485,7 +487,7 @@ class WhisperEncoder(nn.Module):
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||
if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||
raise ValueError(
|
||||
"Only NOPE position embeddings are supported "
|
||||
f"for causal models, but got {self.pos_embed_type}"
|
||||
@@ -536,8 +538,11 @@ class WhisperEncoder(nn.Module):
|
||||
hidden_states.append(embeds)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
if input_is_batched:
|
||||
if input_is_batched or self.is_causal:
|
||||
# Models using WhisperEncoder may handle batching internally.
|
||||
# If WhisperEncoder is causal, sequences
|
||||
# are not padded to have identical seq length (T)
|
||||
# => concat over feature dim
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user