From 74c583bc508c2dafb9e95bab3b635884e4a021f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Mon, 19 Jan 2026 11:02:31 +0100
Subject: [PATCH] [Core] Whisper support `torch.compile` (#30385)

Signed-off-by: NickLucche
---
 .../correctness/test_transcription_api_correctness.py | 4 +++-
 vllm/compilation/decorators.py                        | 7 +++++++
 vllm/forward_context.py                               | 7 +++++++
 vllm/model_executor/models/whisper.py                 | 2 ++
 vllm/v1/worker/gpu_model_runner.py                    | 8 ++++++++
 5 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
index 7821ade63..2725a1295 100644
--- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
+++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
@@ -156,7 +156,9 @@ def test_wer_correctness(
     model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
 ):
     # TODO refactor to use `ASRDataset`
-    with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
+    with RemoteOpenAIServer(
+        model_name, ["--enforce-eager"], max_wait_seconds=480
+    ) as remote_server:
         dataset = load_hf_dataset(dataset_repo)
 
         if not max_concurrent_request:
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 7d9fd0d2f..3d9d77421 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -25,6 +25,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.config.compilation import DynamicShapesType
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -388,6 +389,12 @@ def _support_torch_compile(
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
+        # enc-dec models where tensor shapes/types vary across invocations, preventing
+        # the capture of a single computational graph.
+        if is_forward_context_available() and get_forward_context().skip_compiled:
+            return self.forward(*args, **kwargs)
+
         # if aot_compiled_fn is set, call it with partition wrapper context.
         # The partition wrapper must be active at runtime for CUDA graph
         # capture to work correctly with inductor graph partitioning.
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 9ef0569e8..ed91af44a 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -207,6 +207,9 @@ class ForwardContext:
 
     ubatch_slices: UBatchSlices | None = None
 
+    # If True, bypass the compiled model call, e.g. by using .forward() directly
+    skip_compiled: bool = False
+
     additional_kwargs: dict[str, Any] = field(default_factory=dict)
 
     def __post_init__(self):
@@ -240,6 +243,7 @@ def create_forward_context(
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
     additional_kwargs: dict[str, Any] | None = None,
+    skip_compiled: bool = False,
 ):
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
@@ -249,6 +253,7 @@ def create_forward_context(
         cudagraph_runtime_mode=cudagraph_runtime_mode,
         batch_descriptor=batch_descriptor,
         ubatch_slices=ubatch_slices,
+        skip_compiled=skip_compiled,
         additional_kwargs=additional_kwargs or {},
     )
 
@@ -278,6 +283,7 @@ def set_forward_context(
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
+    skip_compiled: bool = False,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -336,6 +342,7 @@ def set_forward_context(
         batch_descriptor,
         ubatch_slices,
         additional_kwargs,
+        skip_compiled,
     )
 
     try:
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 379a61cea..8d6726145 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -19,6 +19,7 @@ from transformers import (
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
         return self.forward_layers(hidden_states)
 
 
+@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
 class WhisperDecoder(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5691a7698..32a07d64a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3268,6 +3268,13 @@ class GPUModelRunner(
             # Mark KV scales as calculated after the first forward pass
             self.calculate_kv_scales = False
 
+        # Encoder-decoder models can only compile the pure decode steps where no
+        # encoder inputs are present. Use eager for the first pass.
+        num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs)
+        has_encoder_input = (
+            self.model_config.is_encoder_decoder and num_encoder_reqs > 0
+        )
+
         # Run the model.
         # Use persistent buffers for CUDA graphs.
         with (
@@ -3279,6 +3286,7 @@ class GPUModelRunner(
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
                 ubatch_slices=ubatch_slices_padded,
+                skip_compiled=has_encoder_input,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
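
Taken together, the pieces above work as follows: the model runner sets a per-pass
skip_compiled flag on the forward context whenever an encoder-decoder model has
encoder inputs scheduled, and the support_torch_compile wrapper consults that flag
to fall back to plain .forward() for that pass. The snippet below is a minimal,
self-contained sketch of that mechanism under simplified assumptions; ForwardContext,
set_forward_context, Decoder and run_decoder here are stand-ins written for this
note, not vLLM's real classes.

    # Sketch of the skip_compiled mechanism with simplified stand-ins.
    from contextlib import contextmanager
    from dataclasses import dataclass

    import torch
    import torch.nn as nn


    @dataclass
    class ForwardContext:
        # Mirrors the field added to vllm/forward_context.py in this patch.
        skip_compiled: bool = False


    _forward_context: ForwardContext | None = None


    @contextmanager
    def set_forward_context(skip_compiled: bool = False):
        # Stand-in for vllm.forward_context.set_forward_context: installs a
        # context for the duration of one forward pass.
        global _forward_context
        prev = _forward_context
        _forward_context = ForwardContext(skip_compiled=skip_compiled)
        try:
            yield
        finally:
            _forward_context = prev


    class Decoder(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.proj(x))


    decoder = Decoder()
    compiled_decoder = torch.compile(decoder)  # compiled entry point, built once


    def run_decoder(x: torch.Tensor) -> torch.Tensor:
        # Same shape as the check added to _support_torch_compile: when the
        # current forward context sets skip_compiled, call .forward() directly.
        ctx = _forward_context
        if ctx is not None and ctx.skip_compiled:
            return decoder.forward(x)
        return compiled_decoder(x)


    x = torch.randn(2, 8)

    # First pass of an encoder-decoder request (encoder inputs present): eager.
    with set_forward_context(skip_compiled=True):
        eager_out = run_decoder(x)

    # Decode-only steps: reuse the compiled graph.
    with set_forward_context(skip_compiled=False):
        compiled_out = run_decoder(x)

    print(torch.allclose(eager_out, compiled_out, atol=1e-5))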
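
The @support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
decoration on WhisperDecoder marks which dimensions of the decorated forward's
arguments vary across steps (the token dimension of input_ids and the last
dimension of positions), so one graph with symbolic sizes is compiled rather than
one graph per observed length. A rough illustration of the same idea with plain
torch.compile and torch._dynamo.mark_dynamic (an assumed stand-in for this note,
not vLLM's decorator machinery):

    # Assumed illustration of dynamic-shape marking, not vLLM's decorator.
    import torch
    import torch._dynamo as dynamo
    import torch.nn as nn

    model = nn.Linear(16, 16)
    compiled = torch.compile(model)

    for num_tokens in (8, 24, 3):
        x = torch.randn(num_tokens, 16)
        # Mark dim 0 (the token dimension) as dynamic, analogous to
        # dynamic_arg_dims={"input_ids": 0}. Dynamo then traces a graph with a
        # symbolic size for that dimension instead of specializing per length.
        dynamo.mark_dynamic(x, 0)
        print(compiled(x).shape)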