diff --git a/vllm/model_executor/models/voxtral_realtime.py b/vllm/model_executor/models/voxtral_realtime.py
index cc556ac82..8159daeb6 100644
--- a/vllm/model_executor/models/voxtral_realtime.py
+++ b/vllm/model_executor/models/voxtral_realtime.py
@@ -41,6 +41,7 @@ from vllm.multimodal.processing.processor import (
 )
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers import cached_tokenizer_from_config
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 from .utils import (
     _flatten_embeddings,
@@ -337,9 +338,21 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
         assert input_ids is not None
 
         pool_size = self.config.audio_config.block_pool_size
-        inputs_embeds = inputs_embeds.view(
-            inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
-        )
+        if is_torch_equal_or_newer("2.11"):
+            inputs_embeds = inputs_embeds.view(
+                inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
+            )
+        else:
+            # Use reshape + clone to break the view chain and avoid the
+            # output-aliases-input bug in torch.compile's AOT autograd cache.
+            # Without clone(), if any downstream operation returns a view that's
+            # connected to this view of inputs_embeds, the AOT autograd cache
+            # fails to pickle the ViewMetaSequence containing SymInt shapes.
+            # This is fixed in PyTorch 2.11 and beyond.
+            # Issue: https://github.com/pytorch/pytorch/issues/174299
+            inputs_embeds = inputs_embeds.reshape(
+                inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
+            ).clone()
         whisper_positions = _expand_tensor(positions, pool_size)
         audio_hidden_states = self.whisper_encoder.whisper_encoder(
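For context, here is a minimal standalone sketch of the version-gated pooling reshape the patch introduces. The helper names `_torch_at_least` and `pool_embeddings` are illustrative (not from the patch); `_torch_at_least` stands in for vLLM's `is_torch_equal_or_newer` so the sketch stays self-contained.

```python
# Minimal sketch, assuming a (seq_len, hidden) tensor and a pool_size that
# evenly divides the hidden dimension.
import torch
from packaging import version


def _torch_at_least(v: str) -> bool:
    # Stand-in for vllm.utils.torch_utils.is_torch_equal_or_newer.
    return version.parse(torch.__version__).release >= version.parse(v).release


def pool_embeddings(inputs_embeds: torch.Tensor, pool_size: int) -> torch.Tensor:
    new_shape = (
        inputs_embeds.shape[0] * pool_size,
        inputs_embeds.shape[1] // pool_size,
    )
    if _torch_at_least("2.11"):
        # A plain view is fine once the AOT autograd cache bug is fixed upstream.
        return inputs_embeds.view(new_shape)
    # reshape + clone materializes a fresh tensor, so no downstream operation can
    # return a view that aliases the compiled graph's input
    # (pytorch/pytorch#174299).
    return inputs_embeds.reshape(new_shape).clone()


embeds = torch.randn(4, 8)
pooled = pool_embeddings(embeds, pool_size=2)  # -> shape (8, 4)
```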