[Voxtral Realtime] Fix engine crash on empty multimodal embeddings (#34862)

Signed-off-by: Tal Nir <tal@nervexneurotech.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Tal Nir
2026-02-19 02:21:47 -05:00
committed by GitHub
parent 7f51e93864
commit f75b61a9e9
2 changed files with 101 additions and 10 deletions

View File

@@ -299,13 +299,29 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Pass post-conv embeddings directly as input"""
# for realtime we simply flatten the multimodal embeddings
# to be in tensor format, we treat the input ids later
assert multimodal_embeddings is not None
assert len(multimodal_embeddings) > 0, (
"For realtime you must provide a multimodal_embedding at every step."
)
"""Pass post-conv embeddings directly as input.
For realtime models, multimodal embeddings are required at every
decode step. If they are missing (e.g. due to an empty audio
commit, encoder-cache eviction under GPU memory pressure, or a
client disconnect), return zero embeddings instead of crashing
the engine so that all other in-flight requests stay alive.
"""
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
logger.warning(
"Realtime model received empty multimodal embeddings "
"for %d input tokens. Returning zero embeddings to "
"avoid engine crash.",
input_ids.shape[0],
)
pool_size = self.config.audio_config.block_pool_size
embed_dim = self.config.audio_config.d_model * pool_size
return torch.zeros(
input_ids.shape[0],
embed_dim,
dtype=self.whisper_encoder.dtype,
device=input_ids.device,
)
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
return mm_embeds_flat
@@ -367,9 +383,12 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
"""Transform audio waveforms -> initial whisper post-conv embeddings"""
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
assert audio_inputs is not None, (
"For realtime you must provide an audio input at every step."
)
if audio_inputs is None:
logger.warning(
"Realtime model received no audio inputs in "
"embed_multimodal. Returning empty embeddings."
)
return []
def _truncate_left(
sample: torch.Tensor, mult_of: int, pos: int