Make voxtral compile friendly (#33959)
Signed-off-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
committed by
GitHub
parent
c870eb9e0f
commit
f1c664545b
@@ -41,6 +41,7 @@ from vllm.multimodal.processing.processor import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .utils import (
|
||||
_flatten_embeddings,
|
||||
@@ -337,9 +338,21 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
|
||||
assert input_ids is not None
|
||||
|
||||
pool_size = self.config.audio_config.block_pool_size
|
||||
if is_torch_equal_or_newer("2.11"):
|
||||
inputs_embeds = inputs_embeds.view(
|
||||
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
|
||||
)
|
||||
else:
|
||||
# TODO: Use reshape + clone to break the view chain and avoid the
|
||||
# output-aliases-input bug in torch.compile's AOT autograd cache.
|
||||
# Without clone(), if any downstream operation returns a view that's
|
||||
# connected to this view of inputs_embeds, the AOT autograd cache
|
||||
# fails to pickle the ViewMetaSequence containing SymInt shapes.
|
||||
# This will be fixed in PyTorch 2.11 and later.
|
||||
# issue: https://github.com/pytorch/pytorch/issues/174299
|
||||
inputs_embeds = inputs_embeds.reshape(
|
||||
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
|
||||
).clone()
|
||||
|
||||
whisper_positions = _expand_tensor(positions, pool_size)
|
||||
audio_hidden_states = self.whisper_encoder.whisper_encoder(
|
||||
|
||||
Reference in New Issue
Block a user