From f1c664545b954e30ac4887e32ce8f73b39310a9a Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 24 Feb 2026 16:33:35 +0800 Subject: [PATCH] Make voxtral compile friendly (#33959) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tugsbayasgalan Manlaibaatar Co-authored-by: Nicolò Lucchesi --- .../model_executor/models/voxtral_realtime.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/voxtral_realtime.py b/vllm/model_executor/models/voxtral_realtime.py index cc556ac82..8159daeb6 100644 --- a/vllm/model_executor/models/voxtral_realtime.py +++ b/vllm/model_executor/models/voxtral_realtime.py @@ -41,6 +41,7 @@ from vllm.multimodal.processing.processor import ( ) from vllm.sequence import IntermediateTensors from vllm.tokenizers import cached_tokenizer_from_config +from vllm.utils.torch_utils import is_torch_equal_or_newer from .utils import ( _flatten_embeddings, @@ -337,9 +338,21 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim assert input_ids is not None pool_size = self.config.audio_config.block_pool_size - inputs_embeds = inputs_embeds.view( - inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size - ) + if is_torch_equal_or_newer("2.11"): + inputs_embeds = inputs_embeds.view( + inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size + ) + else: + # Use reshape + clone to break the view chain and avoid output + # aliasing input bug in torch.compile's AOT autograd cache. + # Without clone(), if any downstream operation returns a view that's + # connected to this view of inputs_embeds, the AOT autograd cache + # fails to pickle the ViewMetaSequence containing SymInt shapes. + # This will be fixed in pytorch 2.11 and beyond.
+ # issue: https://github.com/pytorch/pytorch/issues/174299 + inputs_embeds = inputs_embeds.reshape( + inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size + ).clone() whisper_positions = _expand_tensor(positions, pool_size) audio_hidden_states = self.whisper_encoder.whisper_encoder(