[torch.compile][BE] Modify cudagraph callable to check for is_forward_context_set (#36288)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -34,9 +34,6 @@ relies on caching artifacts to reduce start time, we must properly propagate the
|
||||
with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
|
||||
components (see Compile Range Integration).
|
||||
|
||||
3. `with set_forward_context` context manager should be used around the nn.Module's forward call. This will properly forward the vllm_config which is needed
|
||||
for torch.compile integration.
|
||||
|
||||
### CompilationConfig
|
||||
|
||||
With the exception of `compile_mm_encoder: true`, the multimodal encoder will inherit from the same compilation config as the text LLM. We may extend
|
||||
|
||||
@@ -16,7 +16,11 @@ from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.forward_context import (
|
||||
BatchDescriptor,
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
@@ -224,6 +228,12 @@ class CUDAGraphWrapper:
|
||||
self.concrete_cudagraph_entries.clear()
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
|
||||
if not is_forward_context_available():
|
||||
# No forward context means we are outside the normal
|
||||
# inference path (e.g. a vision encoder forward pass).
|
||||
# Just run the underlying function without cudagraphs.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
@@ -38,7 +38,6 @@ from vllm.compilation.decorators import (
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.attention import MMEncoderAttention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration(
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
with (
|
||||
set_forward_context(None, self.vllm_config),
|
||||
):
|
||||
return self._process_image_input(image_input)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -49,7 +49,6 @@ from vllm.compilation.decorators import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.attention import MMEncoderAttention
|
||||
@@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"]
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
@@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"]
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(
|
||||
pixel_values_videos, grid_thw=grid_thw_list
|
||||
)
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
|
||||
Reference in New Issue
Block a user