[Voxtral] Fix speech transcription api (#31388)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Authored by Patrick von Platen on 2026-01-08 12:34:19 +02:00, committed by GitHub
parent 2972a05473
commit 18d4e481d0
5 changed files with 114 additions and 27 deletions

View File

@@ -17,10 +17,11 @@ class SpeechToTextConfig:
     16kHz audio input. The input audio will be automatically resampled to this
     rate before processing."""
-    max_audio_clip_s: int = 30
+    max_audio_clip_s: int | None = 30
     """Maximum duration in seconds for a single audio clip without chunking.
     Audio longer than this will be split into smaller chunks if
-    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected.
+    `None` means audio duration can be unlimited and won't be chunked."""
     overlap_chunk_second: int = 1
     """Overlap duration in seconds between consecutive audio chunks when

View File

@@ -477,7 +477,15 @@ class OpenAISpeechToText(OpenAIServing):
         }
         segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
         text = ""
+        chunk_size_in_s = self.asr_config.max_audio_clip_s
+        if chunk_size_in_s is None:
+            assert len(list_result_generator) == 1, (
+                "`max_audio_clip_s` is set to None, audio cannot be chunked"
+            )
         for idx, result_generator in enumerate(list_result_generator):
+            start_time = (
+                float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
+            )
             async for op in result_generator:
                 if request.response_format == "verbose_json":
                     segments: list[SpeechToTextSegment] = (
@@ -485,7 +493,7 @@ class OpenAISpeechToText(OpenAIServing):
                             tokens=tuple(op.outputs[0].token_ids),
                             segment_class=segment_class,
                             request=request,
-                            start_time=idx * self.asr_config.max_audio_clip_s,
+                            start_time=start_time,
                         )
                     )
@@ -653,6 +661,10 @@ class OpenAISpeechToText(OpenAIServing):
     def _split_audio(
         self, audio_data: np.ndarray, sample_rate: int
     ) -> list[np.ndarray]:
+        assert self.asr_config.max_audio_clip_s is not None, (
+            f"{self.asr_config.max_audio_clip_s=} cannot be None to"
+            " split audio into chunks."
+        )
         chunk_size = sample_rate * self.asr_config.max_audio_clip_s
         overlap_size = sample_rate * self.asr_config.overlap_chunk_second
         chunks = []
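
For context, with the defaults (`sample_rate=16000`, `max_audio_clip_s=30`, `overlap_chunk_second=1`) the assertion above guards arithmetic like `chunk_size = 16000 * 30 = 480000` samples and `overlap_size = 16000` samples, so consecutive chunks share one second of audio. A rough standalone sketch of that overlapped split (not the exact vLLM implementation):

import numpy as np

def split_audio(audio: np.ndarray, sample_rate: int = 16000,
                max_audio_clip_s: int = 30, overlap_s: int = 1) -> list[np.ndarray]:
    chunk_size = sample_rate * max_audio_clip_s   # 480_000 samples at defaults
    overlap_size = sample_rate * overlap_s        # 16_000 samples at defaults
    chunks, start = [], 0
    while start < len(audio):
        chunks.append(audio[start:start + chunk_size])
        start += chunk_size - overlap_size
    return chunks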

View File

@@ -17,7 +17,11 @@ from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChu
 from mistral_common.protocol.instruct.messages import UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
-from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
+from mistral_common.tokens.tokenizers.audio import (
+    Audio,
+    AudioEncoder,
+    TranscriptionFormat,
+)
 from transformers import BatchFeature, TensorType, WhisperConfig
 from transformers.tokenization_utils_base import TextInput
@@ -157,6 +161,10 @@ class VoxtralProcessorAdapter:
         # pad if necessary
         # TODO(Patrick) - remove once mistral-common is bumped
-        sig = inspect.signature(self._audio_processor.pad)
-        if "is_online_streaming" in sig.parameters:
-            audio = self._audio_processor.pad(
+        if (
+            self._audio_processor.audio_config.transcription_format
+            != TranscriptionFormat.STREAMING
+        ):
+            sig = inspect.signature(self._audio_processor.pad)
+            if "is_online_streaming" in sig.parameters:
+                audio = self._audio_processor.pad(
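
For context, the `inspect.signature` check kept inside the new guard is a feature-detection pattern: a keyword argument is only forwarded when the installed mistral-common version accepts it. A generic, standalone sketch of the same idea (the callable and the flag value shown here are illustrative):

import inspect

def pad_compat(pad_fn, audio, **kwargs):
    # Forward is_online_streaming only if the installed pad() accepts it.
    if "is_online_streaming" in inspect.signature(pad_fn).parameters:
        return pad_fn(audio, is_online_streaming=False, **kwargs)
    return pad_fn(audio, **kwargs)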

View File

@@ -3,10 +3,19 @@
 import math
 from collections.abc import Mapping
+from typing import Literal, cast
 
+import numpy as np
 import torch
+from mistral_common.protocol.instruct.chunk import RawAudio
+from mistral_common.protocol.transcription.request import (
+    StreamingMode,
+    TranscriptionRequest,
+)
+from mistral_common.tokens.tokenizers.audio import Audio
 
-from vllm.config.vllm import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.voxtral import (
@@ -27,6 +36,7 @@ from vllm.multimodal.processing import (
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.tokenizers import cached_tokenizer_from_config
 
 from .utils import (
     _flatten_embeddings,
@@ -205,25 +215,34 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
"For streaming you must provide an audio input at every step." "For streaming you must provide an audio input at every step."
) )
multiple_of = self.audio_config.raw_audio_length_per_tok def _truncate_left(
assert all( sample: torch.Tensor, mult_of: int, pos: int
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs ) -> torch.Tensor:
), ( assert pos in [0, 1], pos
f"Every input audio waveform has to be a multiple of {multiple_of}, but" if (ctx := sample.shape[pos] % mult_of) != 0:
f" one is {this_audio} with {(this_audio / multiple_of)=}." sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
assert sample.shape[pos] > 0, (
f"Sample is empty after truncation with ctx {ctx}"
) )
return sample
mel_features = [ mel_features = [
self.whisper_encoder.compute_whisper_melspec(audio).to( self.whisper_encoder.compute_whisper_melspec(audio).to(
self.whisper_encoder.dtype self.whisper_encoder.dtype
) )
for audio in audio_inputs for audio in audio_inputs
] ]
# we truncate the left most mel feature
# if the sequence length in impair
mel_features = [_truncate_left(mel, 2, 1) for mel in mel_features]
seq_lens = [mel.shape[1] for mel in mel_features] seq_lens = [mel.shape[1] for mel in mel_features]
# [total_num_20ms_frames, hidden_size] # [total_num_20ms_frames, hidden_size]
audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv( audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
mel_features mel_features
)[0] )
conv_stride = self.whisper_encoder.whisper_encoder.total_stride conv_stride = self.whisper_encoder.whisper_encoder.total_stride
audio_embeddings_per_sample = audio_embeddings.split( audio_embeddings_per_sample = audio_embeddings.split(
[s // conv_stride for s in seq_lens], dim=0 [s // conv_stride for s in seq_lens], dim=0
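
An aside on the new `_truncate_left` helper introduced in the hunk above: instead of asserting that every waveform or embedding length is an exact multiple of the required stride, leading frames are now dropped until it is. A toy reproduction of that behaviour, assuming the same logic as in the diff:

import torch

def _truncate_left(sample: torch.Tensor, mult_of: int, pos: int) -> torch.Tensor:
    # Drop leading elements along dim `pos` until shape[pos] % mult_of == 0.
    if (ctx := sample.shape[pos] % mult_of) != 0:
        sample = sample[ctx:] if pos == 0 else sample[:, ctx:]
    return sample

mel = torch.randn(128, 301)                # odd number of mel frames
print(_truncate_left(mel, 2, 1).shape)     # torch.Size([128, 300])
emb = torch.randn(10, 64)
print(_truncate_left(emb, 4, 0).shape)     # torch.Size([8, 64])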
@@ -231,13 +250,55 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
         # audio_embeddings per sample need to be divisible by 4
         pool_size = self.config.audio_config.block_pool_size
-        assert all(
-            (this_shape := sample.shape[0]) % pool_size == 0
-            for sample in audio_embeddings_per_sample
-        ), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
+        audio_embeddings_per_sample = [
+            _truncate_left(sample, pool_size, 0)
+            for sample in audio_embeddings_per_sample
+        ]
         audio_embeddings_per_sample = [
             e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
             for e in audio_embeddings_per_sample
         ]
         return audio_embeddings_per_sample
 
+    @classmethod
+    def get_speech_to_text_config(
+        cls, model_config: ModelConfig, task_type: str
+    ) -> SpeechToTextConfig:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio_config = tokenizer.instruct.audio_encoder.audio_config
+        sample_rate = audio_config.sampling_rate
+        return SpeechToTextConfig(
+            max_audio_clip_s=None,  # only limited by memory
+            sample_rate=sample_rate,
+            min_energy_split_window_size=None,
+        )
+
+    @classmethod
+    # for speech-to-text transcription
+    def get_generation_prompt(
+        cls,
+        audio: np.ndarray,
+        model_config: ModelConfig,
+        stt_config: SpeechToTextConfig,
+        language: str | None,
+        task_type: Literal["transcribe", "translate"],
+        request_prompt: str,
+        to_language: str | None,
+    ) -> PromptType:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
+        req = TranscriptionRequest(
+            model=model_config.model,
+            audio=RawAudio.from_audio(audio),
+            language=language,
+            streaming=StreamingMode.OFFLINE,
+        )
+        tokenized = tokenizer.instruct.encode_transcription(req)
+        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}}
+        prompts_dict["prompt_token_ids"] = tokenized.tokens
+        return cast(PromptType, prompts_dict)
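
For reference (not part of the diff): with `max_audio_clip_s=None` the server no longer splits Voxtral audio into chunks, so a request is handled as a single transcription. A minimal client-side check against a running vLLM OpenAI-compatible server might look like this, where the model name and file path are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
with open("sample.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="mistralai/Voxtral-Mini-3B-2507",  # placeholder model name
        file=f,
        language="en",
    )
print(transcription.text)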

View File

@@ -469,8 +469,10 @@ class WhisperEncoder(nn.Module):
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
-        is_causal = getattr(config, "is_causal", False)
-        Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
+        self.is_causal = getattr(config, "is_causal", False)
+        Conv1d = (
+            WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
+        )
         self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
         self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
@@ -485,7 +487,7 @@ class WhisperEncoder(nn.Module):
         )
         self.layer_norm = nn.LayerNorm(config.d_model)
-        if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
+        if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
             raise ValueError(
                 "Only NOPE position embeddings are supported "
                 f"for causal models, but got {self.pos_embed_type}"
@@ -536,8 +538,11 @@ class WhisperEncoder(nn.Module):
             hidden_states.append(embeds)
         input_is_batched = embeds.ndim > 2
         # Input to MHA must be B x T x D
-        if input_is_batched:
+        if input_is_batched or self.is_causal:
             # Models using WhisperEncoder may handle batching internally.
+            # If WhisperEncoder is causal, sequences
+            # are not padded to have identical seq length (T)
+            # => concat over feature dim
             hidden_states = torch.cat(hidden_states)
         else:
             hidden_states = torch.stack(hidden_states, dim=0)
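
For context on this last hunk: a causal WhisperEncoder receives sequences of different lengths, so they cannot be stacked into a single (B, T, D) tensor without padding, but they can be concatenated into one flat tensor. A two-line illustration:

import torch

seqs = [torch.randn(7, 4), torch.randn(5, 4)]   # unequal sequence lengths
flat = torch.cat(seqs)                          # ok: shape (12, 4)
# torch.stack(seqs, dim=0)                      # would fail: sizes 7 vs 5 differ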