[Voxtral] Fix speech transcription api (#31388)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Author: Patrick von Platen
Date: 2026-01-08 12:34:19 +02:00
Committed by: GitHub
Parent: 2972a05473
Commit: 18d4e481d0
5 changed files with 114 additions and 27 deletions


@@ -17,10 +17,11 @@ class SpeechToTextConfig:
     16kHz audio input. The input audio will be automatically resampled to this
     rate before processing."""
-    max_audio_clip_s: int = 30
+    max_audio_clip_s: int | None = 30
     """Maximum duration in seconds for a single audio clip without chunking.
     Audio longer than this will be split into smaller chunks if
-    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected.
+    `None` means audio duration can be unlimited and won't be chunked."""
     overlap_chunk_second: int = 1
     """Overlap duration in seconds between consecutive audio chunks when


@@ -477,7 +477,15 @@ class OpenAISpeechToText(OpenAIServing):
         }
         segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
         text = ""
+        chunk_size_in_s = self.asr_config.max_audio_clip_s
+        if chunk_size_in_s is None:
+            assert len(list_result_generator) == 1, (
+                "`max_audio_clip_s` is set to None, audio cannot be chunked"
+            )
         for idx, result_generator in enumerate(list_result_generator):
+            start_time = (
+                float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
+            )
             async for op in result_generator:
                 if request.response_format == "verbose_json":
                     segments: list[SpeechToTextSegment] = (
@@ -485,7 +493,7 @@ class OpenAISpeechToText(OpenAIServing):
                             tokens=tuple(op.outputs[0].token_ids),
                             segment_class=segment_class,
                             request=request,
-                            start_time=idx * self.asr_config.max_audio_clip_s,
+                            start_time=start_time,
                         )
                     )
@@ -653,6 +661,10 @@ class OpenAISpeechToText(OpenAIServing):
     def _split_audio(
         self, audio_data: np.ndarray, sample_rate: int
     ) -> list[np.ndarray]:
+        assert self.asr_config.max_audio_clip_s is not None, (
+            f"{self.asr_config.max_audio_clip_s=} cannot be None to"
+            " split audio into chunks."
+        )
         chunk_size = sample_rate * self.asr_config.max_audio_clip_s
         overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
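
For readers following along: the arithmetic above implies stepping through the waveform in strides of `chunk_size - overlap_size`. A self-contained sketch under that assumption (it ignores the energy-based split-point search the real `_split_audio` also performs; `split_audio_evenly` is a hypothetical name):

```python
import numpy as np


def split_audio_evenly(
    audio: np.ndarray, sample_rate: int, max_clip_s: int, overlap_s: int
) -> list[np.ndarray]:
    chunk_size = sample_rate * max_clip_s
    overlap_size = sample_rate * overlap_s
    if len(audio) <= chunk_size:
        return [audio]
    step = chunk_size - overlap_size  # consecutive chunks share overlap_size samples
    return [audio[start : start + chunk_size] for start in range(0, len(audio), step)]


audio = np.zeros(16_000 * 70)  # 70 s of 16 kHz audio
print([len(c) / 16_000 for c in split_audio_evenly(audio, 16_000, 30, 1)])
# [30.0, 30.0, 12.0]
```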


@@ -17,7 +17,11 @@ from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
 from mistral_common.protocol.instruct.messages import UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
-from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
+from mistral_common.tokens.tokenizers.audio import (
+    Audio,
+    AudioEncoder,
+    TranscriptionFormat,
+)
 from transformers import BatchFeature, TensorType, WhisperConfig
 from transformers.tokenization_utils_base import TextInput
@@ -157,13 +161,17 @@ class VoxtralProcessorAdapter:
         # pad if necessary
         # TODO(Patrick) - remove once mistral-common is bumped
-        sig = inspect.signature(self._audio_processor.pad)
-        if "is_online_streaming" in sig.parameters:
-            audio = self._audio_processor.pad(
-                audio, self.sampling_rate, is_online_streaming=False
-            )
-        else:
-            audio = self._audio_processor.pad(audio, self.sampling_rate)
+        if (
+            self._audio_processor.audio_config.transcription_format
+            != TranscriptionFormat.STREAMING
+        ):
+            sig = inspect.signature(self._audio_processor.pad)
+            if "is_online_streaming" in sig.parameters:
+                audio = self._audio_processor.pad(
+                    audio, self.sampling_rate, is_online_streaming=False
+                )
+            else:
+                audio = self._audio_processor.pad(audio, self.sampling_rate)
         audio_tokens = [self.begin_audio_token_id] + [
             self.audio_token_id
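
The `inspect.signature` guard that survives inside the new branch is a version-compatibility pattern: probe for a keyword argument before passing it. A standalone illustration (the two `pad` stubs stand in for older and newer mistral-common signatures):

```python
import inspect


def pad_v1(audio, sampling_rate):  # stand-in for the older signature
    return audio


def pad_v2(audio, sampling_rate, is_online_streaming=False):  # newer signature
    return audio


for pad in (pad_v1, pad_v2):
    # call with the keyword only if this version of pad() accepts it
    if "is_online_streaming" in inspect.signature(pad).parameters:
        pad([0.0], 16_000, is_online_streaming=False)
    else:
        pad([0.0], 16_000)
```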


@@ -3,10 +3,19 @@
 import math
 from collections.abc import Mapping
 from typing import Literal, cast

+import numpy as np
 import torch
+from mistral_common.protocol.instruct.chunk import RawAudio
+from mistral_common.protocol.transcription.request import (
+    StreamingMode,
+    TranscriptionRequest,
+)
+from mistral_common.tokens.tokenizers.audio import Audio

-from vllm.config.vllm import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.voxtral import (
@@ -27,6 +36,7 @@ from vllm.multimodal.processing import (
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.tokenizers import cached_tokenizer_from_config

 from .utils import (
     _flatten_embeddings,
@@ -205,13 +215,17 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
"For streaming you must provide an audio input at every step."
)
multiple_of = self.audio_config.raw_audio_length_per_tok
assert all(
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
), (
f"Every input audio waveform has to be a multiple of {multiple_of}, but"
f" one is {this_audio} with {(this_audio / multiple_of)=}."
)
def _truncate_left(
sample: torch.Tensor, mult_of: int, pos: int
) -> torch.Tensor:
assert pos in [0, 1], pos
if (ctx := sample.shape[pos] % mult_of) != 0:
sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
assert sample.shape[pos] > 0, (
f"Sample is empty after truncation with ctx {ctx}"
)
return sample
mel_features = [
self.whisper_encoder.compute_whisper_melspec(audio).to(
@@ -219,11 +233,16 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
             )
             for audio in audio_inputs
         ]
+        # we truncate the left-most mel frames
+        # if the sequence length is odd
+        mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features]
+        seq_lens = [mel.shape[1] for mel in mel_features]

         # [total_num_20ms_frames, hidden_size]
         audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
             mel_features
-        )[0]
+        )
         conv_stride = self.whisper_encoder.whisper_encoder.total_stride
         audio_embeddings_per_sample = audio_embeddings.split(
             [s // conv_stride for s in seq_lens], dim=0
@@ -231,13 +250,55 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
         # audio_embeddings per sample need to be divisible by 4
         pool_size = self.config.audio_config.block_pool_size
-        assert all(
-            (this_shape := sample.shape[0]) % pool_size == 0
-            for sample in audio_embeddings_per_sample
-        ), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
+        audio_embeddings_per_sample = [
+            _truncate_left(sample, pool_size, 0)
+            for sample in audio_embeddings_per_sample
+        ]
         audio_embeddings_per_sample = [
             e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
             for e in audio_embeddings_per_sample
         ]
         return audio_embeddings_per_sample
+
+    @classmethod
+    def get_speech_to_text_config(
+        cls, model_config: ModelConfig, task_type: str
+    ) -> SpeechToTextConfig:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio_config = tokenizer.instruct.audio_encoder.audio_config
+        sample_rate = audio_config.sampling_rate
+        return SpeechToTextConfig(
+            max_audio_clip_s=None,  # only limited by memory
+            sample_rate=sample_rate,
+            min_energy_split_window_size=None,
+        )
+
+    # for speech-to-text transcription
+    @classmethod
+    def get_generation_prompt(
+        cls,
+        audio: np.ndarray,
+        model_config: ModelConfig,
+        stt_config: SpeechToTextConfig,
+        language: str | None,
+        task_type: Literal["transcribe", "translate"],
+        request_prompt: str,
+        to_language: str | None,
+    ) -> PromptType:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
+        req = TranscriptionRequest(
+            model=model_config.model,
+            audio=RawAudio.from_audio(audio),
+            language=language,
+            streaming=StreamingMode.OFFLINE,
+        )
+        tokenized = tokenizer.instruct.encode_transcription(req)
+        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}}
+        prompts_dict["prompt_token_ids"] = tokenized.tokens
+        return cast(PromptType, prompts_dict)
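
A quick standalone check of the left-truncate-then-pool reshape introduced above: trim frames on the left until the count divides `pool_size`, then fold every `pool_size` consecutive frames into one wider embedding (shapes here are made up for illustration):

```python
import torch

pool_size = 4
sample = torch.randn(10, 8)  # [num_frames, hidden_size]

# same effect as _truncate_left(sample, pool_size, 0)
ctx = sample.shape[0] % pool_size  # 2 leftover frames
sample = sample[ctx:]              # -> [8, 8]

pooled = sample.view(sample.shape[0] // pool_size, sample.shape[1] * pool_size)
print(pooled.shape)  # torch.Size([2, 32])
```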


@@ -469,8 +469,10 @@ class WhisperEncoder(nn.Module):
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        is_causal = getattr(config, "is_causal", False)
-        Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
+        self.is_causal = getattr(config, "is_causal", False)
+        Conv1d = (
+            WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
+        )
         self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
         self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
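
`WhisperCausalConv1d` itself is not part of this diff; the usual way to make a `Conv1d` causal is to pad only on the left so no output frame depends on future input. A sketch of that idea, under the assumption that the class does roughly this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv1dSketch(nn.Conv1d):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pad the past only, so output[t] never sees input[t+1:]
        left = (self.kernel_size[0] - 1) * self.dilation[0]
        return super().forward(F.pad(x, (left, 0)))


x = torch.randn(1, 80, 3000)  # [batch, mel_bins, frames]
print(CausalConv1dSketch(80, 384, kernel_size=3)(x).shape)
# torch.Size([1, 384, 3000])
```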
@@ -485,7 +487,7 @@ class WhisperEncoder(nn.Module):
         )
         self.layer_norm = nn.LayerNorm(config.d_model)

-        if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
+        if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
             raise ValueError(
                 "Only NOPE position embeddings are supported "
                 f"for causal models, but got {self.pos_embed_type}"
@@ -536,8 +538,11 @@ class WhisperEncoder(nn.Module):
             hidden_states.append(embeds)
         input_is_batched = embeds.ndim > 2
         # Input to MHA must be B x T x D
-        if input_is_batched:
+        if input_is_batched or self.is_causal:
+            # Models using WhisperEncoder may handle batching internally.
+            # If the WhisperEncoder is causal, sequences are not padded
+            # to identical seq lengths (T) => concat over the time dim
             hidden_states = torch.cat(hidden_states)
         else:
             hidden_states = torch.stack(hidden_states, dim=0)
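
The `or self.is_causal` branch matters because `torch.stack` requires identical shapes, while the causal path produces per-sample `[T_i, D]` tensors with differing `T_i`; `torch.cat` along the time dimension accepts them:

```python
import torch

seqs = [torch.randn(5, 16), torch.randn(7, 16)]  # unequal T_i
print(torch.cat(seqs).shape)  # torch.Size([12, 16])
# torch.stack(seqs) would raise RuntimeError: stack expects each
# tensor to be equal size
```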