Make voxtral compile friendly (#33959)

Signed-off-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Tugsbayasgalan Manlaibaatar authored 2026-02-24 16:33:35 +08:00 · committed by GitHub
commit f1c664545b (parent c870eb9e0f)

@@ -41,6 +41,7 @@ from vllm.multimodal.processing.processor import (
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .utils import (
_flatten_embeddings,
@@ -337,9 +338,21 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
assert input_ids is not None
pool_size = self.config.audio_config.block_pool_size
if is_torch_equal_or_newer("2.11"):
    inputs_embeds = inputs_embeds.view(
        inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
    )
else:
    # Use reshape + clone to break the view chain and avoid the
    # output-aliasing-input bug in torch.compile's AOT autograd cache.
    # Without clone(), if any downstream operation returns a view that is
    # connected to this view of inputs_embeds, the AOT autograd cache
    # fails to pickle the ViewMetaSequence containing SymInt shapes.
    # TODO: drop this branch once PyTorch 2.11, where the bug is fixed,
    # is the minimum supported version.
    # issue: https://github.com/pytorch/pytorch/issues/174299
    inputs_embeds = inputs_embeds.reshape(
        inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
    ).clone()
whisper_positions = _expand_tensor(positions, pool_size)
audio_hidden_states = self.whisper_encoder.whisper_encoder(
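
For context on the workaround, here is a minimal standalone sketch (the tensor shape, pool_size value, and variable names are illustrative, not taken from the Voxtral model) of why the two branches behave differently: view() returns an alias of the input's storage, while reshape(...).clone() materializes fresh storage, which is what breaks the view chain that trips up the AOT autograd cache on PyTorch versions before 2.11.

# Minimal sketch; shapes and values are illustrative, not from the model code.
import torch

pool_size = 4
inputs_embeds = torch.randn(8, 512)  # dummy (num_tokens, hidden_size)

# PyTorch >= 2.11 path: view() is a zero-copy alias of the original storage.
viewed = inputs_embeds.view(
    inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
)
assert viewed.data_ptr() == inputs_embeds.data_ptr()  # shares storage

# Pre-2.11 path: reshape() + clone() copies into fresh storage, so the result
# no longer aliases inputs_embeds and the view chain is broken.
cloned = inputs_embeds.reshape(
    inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
).clone()
assert cloned.data_ptr() != inputs_embeds.data_ptr()  # independent copy

Under torch.compile on older PyTorch, only the clone()d result avoids the aliasing issue, at the cost of one extra copy per forward pass.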