[torch.compile][BE] Modify cudagraph callable to check for is_forward_context_set (#36288)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-03-16 12:42:34 -07:00
committed by GitHub
parent 0fefd00e6c
commit 714c6e0eab
4 changed files with 27 additions and 29 deletions

View File

@@ -34,9 +34,6 @@ relies on caching artifacts to reduce start time, we must properly propagate the
with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
components (see Compile Range Integration).
3. The `with set_forward_context` context manager should be used around the nn.Module's forward call. This will properly forward the vllm_config, which is needed
for torch.compile integration.
### CompilationConfig
With the exception of `compile_mm_encoder: true`, the multimodal encoder will inherit from the same compilation config as the text LLM. We may extend

View File

@@ -16,7 +16,11 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.forward_context import (
BatchDescriptor,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
@@ -224,6 +228,12 @@ class CUDAGraphWrapper:
self.concrete_cudagraph_entries.clear()
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
if not is_forward_context_available():
# No forward context means we are outside the normal
# inference path (e.g. a vision encoder forward pass).
# Just run the underlying function without cudagraphs.
return self.runnable(*args, **kwargs)
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

View File

@@ -38,7 +38,6 @@ from vllm.compilation.decorators import (
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration(
if image_input is None:
return []
with (
set_forward_context(None, self.vllm_config),
):
return self._process_image_input(image_input)
return self._process_image_input(image_input)
def forward(
self,

View File

@@ -49,7 +49,6 @@ from vllm.compilation.decorators import (
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
@@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"]
with set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
@@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration(
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"]
with set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d",
)
else:
video_embeds = self.visual(
pixel_values_videos, grid_thw=grid_thw_list
)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d",
)
else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size