[Core] Whisper support torch.compile (#30385)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2026-01-19 11:02:31 +01:00
Committed by: GitHub
Parent: c0a350ca73
Commit: 74c583bc50
5 changed files with 27 additions and 1 deletion

View File

@@ -156,7 +156,9 @@ def test_wer_correctness(
     model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
 ):
     # TODO refactor to use `ASRDataset`
-    with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
+    with RemoteOpenAIServer(
+        model_name, ["--enforce-eager"], max_wait_seconds=480
+    ) as remote_server:
         dataset = load_hf_dataset(dataset_repo)
         if not max_concurrent_request:

View File

@@ -25,6 +25,7 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.config.compilation import DynamicShapesType from vllm.config.compilation import DynamicShapesType
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -388,6 +389,12 @@ def _support_torch_compile(
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)

+        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
+        # enc-dec models where tensor shapes/types vary across invocations, preventing
+        # the capture of a single computational graph.
+        if is_forward_context_available() and get_forward_context().skip_compiled:
+            return self.forward(*args, **kwargs)
+
         # if aot_compiled_fn is set, call it with partition wrapper context.
         # The partition wrapper must be active at runtime for CUDA graph
         # capture to work correctly with inductor graph partitioning.
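Note: the check above is the entire bypass mechanism. When the per-step forward context has skip_compiled set, the module produced by _support_torch_compile calls the eager forward() instead of its compiled entry point. A minimal, self-contained sketch of that pattern follows; the names (CompiledWrapper, _FORWARD_CONTEXT) are illustrative stand-ins, not vLLM's actual classes or forward-context API.

import torch
import torch.nn as nn

# Illustrative stand-in for vLLM's forward context; only the flag matters here.
_FORWARD_CONTEXT = {"skip_compiled": False}


class CompiledWrapper(nn.Module):
    """Routes calls to either the compiled graph or the eager forward()."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.compiled = torch.compile(module)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same shape as the new check: if this step asked to skip the compiled
        # call, run the eager module directly.
        if _FORWARD_CONTEXT["skip_compiled"]:
            return self.module(x)
        return self.compiled(x)


if __name__ == "__main__":
    wrapper = CompiledWrapper(nn.Linear(8, 8))
    x = torch.randn(2, 8)

    _FORWARD_CONTEXT["skip_compiled"] = True   # e.g. a step that runs the encoder
    y_eager = wrapper(x)

    _FORWARD_CONTEXT["skip_compiled"] = False  # pure decode steps stay compiled
    y_compiled = wrapper(x)

    print(torch.allclose(y_eager, y_compiled, atol=1e-5))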

View File

@@ -207,6 +207,9 @@ class ForwardContext:
     ubatch_slices: UBatchSlices | None = None
+    # If True, bypass the compiled model call, e.g. by using .forward() directly
+    skip_compiled: bool = False
+
     additional_kwargs: dict[str, Any] = field(default_factory=dict)

     def __post_init__(self):
@@ -240,6 +243,7 @@ def create_forward_context(
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
     additional_kwargs: dict[str, Any] | None = None,
+    skip_compiled: bool = False,
 ):
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
@@ -249,6 +253,7 @@ def create_forward_context(
         cudagraph_runtime_mode=cudagraph_runtime_mode,
         batch_descriptor=batch_descriptor,
         ubatch_slices=ubatch_slices,
+        skip_compiled=skip_compiled,
         additional_kwargs=additional_kwargs or {},
     )
@@ -278,6 +283,7 @@ def set_forward_context(
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
+    skip_compiled: bool = False,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -336,6 +342,7 @@ def set_forward_context(
         batch_descriptor,
         ubatch_slices,
         additional_kwargs,
+        skip_compiled,
     )
     try:
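Note: the flag is plumbed as a plain per-step field: ForwardContext gains skip_compiled, create_forward_context forwards it, and the set_forward_context context manager exposes it to model code for the duration of one forward pass. A hedged sketch of that threading pattern follows, using simplified stand-ins (MiniForwardContext, set_ctx, get_ctx are illustrative names; vLLM's real ForwardContext carries much more state).

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator


@dataclass
class MiniForwardContext:
    skip_compiled: bool = False


_current: MiniForwardContext | None = None


def is_ctx_available() -> bool:
    return _current is not None


def get_ctx() -> MiniForwardContext:
    assert _current is not None, "Forward context is not set"
    return _current


@contextmanager
def set_ctx(skip_compiled: bool = False) -> Iterator[MiniForwardContext]:
    # Install a fresh context for one forward pass, then restore the previous one.
    global _current
    prev = _current
    _current = MiniForwardContext(skip_compiled=skip_compiled)
    try:
        yield _current
    finally:
        _current = prev


if __name__ == "__main__":
    with set_ctx(skip_compiled=True):
        print(get_ctx().skip_compiled)  # True only inside this step
    print(is_ctx_available())           # False again once the step is over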

View File

@@ -19,6 +19,7 @@ from transformers import (
 from transformers.models.whisper.modeling_whisper import sinusoids

 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
         return self.forward_layers(hidden_states)


+@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
 class WhisperDecoder(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
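Note: the decorator is applied to WhisperDecoder only, so the variable-shape encoder/cross-attention prefill stays out of the captured graph, while dynamic_arg_dims declares which tensor dimensions may change between steps (dim 0 of input_ids, the last dim of positions), avoiding a recompile per token count. A rough illustration of what a dynamic dimension buys with plain torch.compile follows; this is a toy model, and torch._dynamo.mark_dynamic is used purely to illustrate the idea — how the decorator applies dynamic dims internally is not shown in this diff.

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(1000, 32)
        self.proj = nn.Linear(32, 1000)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(input_ids))


model = ToyDecoder()
compiled = torch.compile(model)

for num_tokens in (1, 4, 16):
    ids = torch.randint(0, 1000, (num_tokens,))
    # Roughly what dynamic_arg_dims={"input_ids": 0} requests: dim 0 (the number
    # of scheduled tokens) is treated as dynamic rather than specialized, so the
    # graph is reused as the token count changes.
    torch._dynamo.mark_dynamic(ids, 0)
    print(num_tokens, compiled(ids).shape)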

View File

@@ -3268,6 +3268,13 @@ class GPUModelRunner(
             # Mark KV scales as calculated after the first forward pass
             self.calculate_kv_scales = False

+        # Encoder-decoder models can only compile the pure decode steps where no
+        # encoder inputs are present. Use eager for the first pass.
+        num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs)
+        has_encoder_input = (
+            self.model_config.is_encoder_decoder and num_encoder_reqs > 0
+        )
+
         # Run the model.
         # Use persistent buffers for CUDA graphs.
         with (
@@ -3279,6 +3286,7 @@
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
                 ubatch_slices=ubatch_slices_padded,
+                skip_compiled=has_encoder_input,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
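Note: has_encoder_input gates the bypass, so the compiled decoder is skipped only on steps that actually run the encoder (the audio prefill of an encoder-decoder request); pure decode steps keep the compiled path. A tiny hedged sketch of that gating decision, with simplified stand-in arguments:

def should_skip_compiled(is_encoder_decoder: bool, num_scheduled_encoder_inputs: int) -> bool:
    # Mirrors has_encoder_input above: only skip when this step carries encoder inputs.
    return is_encoder_decoder and num_scheduled_encoder_inputs > 0


assert should_skip_compiled(True, 2) is True    # Whisper prefill with audio features
assert should_skip_compiled(True, 0) is False   # Whisper pure decode step
assert should_skip_compiled(False, 0) is False  # decoder-only models are unaffected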