[Core][MM] Add mechanism to configure multimodal fields which should stay on CPU (#28168)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Author: Lukas Geiger
Date: 2025-11-07 12:14:29 +00:00
Committed by: GitHub
Parent: 8e19d470af
Commit: e0919f331d
7 changed files with 68 additions and 74 deletions
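The new multimodal_cpu_fields class attribute (see the Qwen2VLForConditionalGeneration hunk below) lets a model declare batched multimodal inputs that should be left on the host rather than moved to the accelerator with the rest of the batch. The grid_thw tensors are only consumed through host-side operations such as .tolist() and Python size arithmetic, so shipping them to the GPU just forces a sync when they are read back. The helper below is a rough sketch of how a model runner might honor such an attribute; the function name and signature are assumptions for illustration, not vLLM's actual API.

# Hypothetical sketch of the consumer side of multimodal_cpu_fields.
# move_mm_kwargs_to_device and its signature are illustrative assumptions,
# not vLLM's actual implementation.
import torch

def move_mm_kwargs_to_device(
    mm_kwargs: dict[str, torch.Tensor],
    device: torch.device,
    cpu_fields: frozenset[str] = frozenset(),
) -> dict[str, torch.Tensor]:
    # Fields such as "image_grid_thw" are only read on the host
    # (.tolist(), size arithmetic), so leaving them on CPU avoids a
    # device round trip and the implicit sync it would trigger.
    return {
        key: value if key in cpu_fields else value.to(device, non_blocking=True)
        for key, value in mm_kwargs.items()
    }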


@@ -798,21 +798,27 @@ class Qwen2VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
         # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
         # compute cu_seqlens
-        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
         # transformers
         x = x.unsqueeze(1)
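The reworked cu_seqlens computation now runs entirely on a CPU grid_thw tensor and copies only the final result to the device. A minimal standalone sketch with made-up grid values, where each row of grid_thw is a (t, h, w) patch count:

# Standalone sketch of the cu_seqlens computation from the hunk above;
# the grid values are made up.
import torch

grid_thw = torch.tensor([[1, 4, 6], [2, 4, 4]], dtype=torch.int32)

# Each item contributes t sequences of h * w patches.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)

# Prepend a leading zero; new_zeros preserves dtype and device, so this
# replaces the earlier F.pad call without changing the result.
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
print(cu_seqlens)  # tensor([ 0, 24, 40, 56], dtype=torch.int32)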
@@ -1211,6 +1217,7 @@ class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1458,7 +1465,6 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"]
@@ -1467,18 +1473,14 @@ class Qwen2VLForConditionalGeneration(
             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()

         return image_embeds.split(sizes)

     def _process_video_input(
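The simplified sizes expression can be sanity-checked with a quick worked example (grids made up). With spatial_merge_size = 2, every 2x2 block of patches is merged into one embedding, so an item with grid (t, h, w) yields t * h * w // 4 output embeddings, and .tolist() stays a cheap host-side call because grid_thw never left the CPU:

# Worked example of the simplified split-size computation.
import torch

grid_thw = torch.tensor([[1, 32, 32], [1, 16, 16]])  # (t, h, w) per image
merge_size = 2  # spatial_merge_size: 2x2 patches merge into one embedding

sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
print(sizes)  # [256, 64] -> lengths passed to image_embeds.split(sizes)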
@@ -1486,26 +1488,22 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()

         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: