[torch.compile][BE] Modify cudagraph callable to check for is_forward_context_set (#36288)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -38,7 +38,6 @@ from vllm.compilation.decorators import (
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.attention import MMEncoderAttention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -872,10 +871,7 @@ class Llama4ForConditionalGeneration(
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
with (
|
||||
set_forward_context(None, self.vllm_config),
|
||||
):
|
||||
return self._process_image_input(image_input)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -49,7 +49,6 @@ from vllm.compilation.decorators import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.attention import MMEncoderAttention
|
||||
@@ -1207,13 +1206,12 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"]
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
@@ -1262,18 +1260,15 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"]
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(
|
||||
pixel_values_videos, grid_thw=grid_thw_list
|
||||
)
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw_list,
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
|
||||
Reference in New Issue
Block a user