[Core] Whisper support torch.compile (#30385)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2026-01-19 11:02:31 +01:00
committed by GitHub
parent c0a350ca73
commit 74c583bc50
5 changed files with 27 additions and 1 deletions

View File

@@ -19,6 +19,7 @@ from transformers import (
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
return self.forward_layers(hidden_states)
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()