From 714c6e0eab76a4fb1394089d848ecfe46408b9c9 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Mon, 16 Mar 2026 12:42:34 -0700 Subject: [PATCH] [torch.compile][BE] Modify cudagraph callable to check for is_forward_context_set (#36288) Signed-off-by: Lucas Kabela --- docs/design/torch_compile_multimodal.md | 3 -- vllm/compilation/cuda_graph.py | 12 +++++++- vllm/model_executor/models/mllama4.py | 6 +--- vllm/model_executor/models/qwen2_5_vl.py | 35 ++++++++++-------------- 4 files changed, 27 insertions(+), 29 deletions(-) diff --git a/docs/design/torch_compile_multimodal.md b/docs/design/torch_compile_multimodal.md index 4abf1d08c..c46bfa832 100644 --- a/docs/design/torch_compile_multimodal.md +++ b/docs/design/torch_compile_multimodal.md @@ -34,9 +34,6 @@ relies on caching artifacts to reduce start time, we must properly propagate the with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder components (see Compile Range Integration). -3. `with set_forward_context` context manager should be used around the nn.Module's forward call. This will properly forward the vllm_config which is needed -for torch.compile integration. - ### CompilationConfig With the exception of `compile_mm_encoder: true`, the multimodal encoder will inherit from the same compilation config as the text LLM. We may extend diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 13e88448c..78841866f 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -16,7 +16,11 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id -from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.forward_context import ( + BatchDescriptor, + get_forward_context, + is_forward_context_available, +) from vllm.logger import init_logger from vllm.model_executor.offloader.base import get_offloader from vllm.platforms import current_platform @@ -224,6 +228,12 @@ class CUDAGraphWrapper: self.concrete_cudagraph_entries.clear() def __call__(self, *args: Any, **kwargs: Any) -> Any | None: + if not is_forward_context_available(): + # No forward context means we are outside the normal + # inference path (e.g. a vision encoder forward pass). + # Just run the underlying function without cudagraphs. + return self.runnable(*args, **kwargs) + forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index da9836a95..a36b1fa57 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -38,7 +38,6 @@ from vllm.compilation.decorators import ( from vllm.config import VllmConfig, set_current_vllm_config from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import set_forward_context from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ( @@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration( if image_input is None: return [] - with ( - set_forward_context(None, self.vllm_config), - ): - return self._process_image_input(image_input) + return self._process_image_input(image_input) def forward( self, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 8e50022f0..ed311ce05 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -49,7 +49,6 @@ from vllm.compilation.decorators import ( from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.attention import MMEncoderAttention @@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - with set_forward_context(None, self.vllm_config): - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" - ) - else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size @@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - with set_forward_context(None, self.vllm_config): - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model( - self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d", - ) - else: - video_embeds = self.visual( - pixel_values_videos, grid_thw=grid_thw_list - ) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d", + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size