[Core][MM] Add mechanism to configure multimodal fields which should stay on CPU (#28168)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
@@ -414,16 +414,10 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
         pos_ids = []
-        # Support both Tensor and list inputs for DP path
-        if isinstance(grid_thw, list):
-            grid_list = grid_thw
-            max_grid_size = max(max(h, w) for _, h, w in grid_list)
-        else:
-            grid_list = grid_thw.tolist()
-            max_grid_size = int(grid_thw[:, 1:].max().item())
-        for t, h, w in grid_list:
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
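With the input now typed as list[list[int]], the maximum spatial extent falls out of plain Python iteration, with no tensor indexing and no .item() call that has to wait on the device. A minimal standalone sketch (grid values invented for illustration):

    grid_thw = [[1, 32, 44], [2, 16, 24]]  # [t, h, w] rows: one image, one two-frame clip
    max_grid_size = max(max(h, w) for _, h, w in grid_thw)  # 44, computed host-side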
@@ -527,24 +521,25 @@ class Qwen3_VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
-        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
+        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
 
-        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
-
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
-        ).cumsum(
-            dim=0,
-            dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
-        )
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
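Each grid contributes t frames of h * w patches, so the boundaries for the variable-length attention kernels come from repeating each frame's patch count t times and taking a cumulative sum. A self-contained sketch of the cu_seqlens computation above, with invented grid values:

    import torch

    grid_thw = torch.tensor([[2, 4, 6], [1, 2, 2]], dtype=torch.int32)  # [t, h, w] per item
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])  # [24, 24, 4]
    cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)  # [24, 48, 52]
    cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])  # [0, 24, 48, 52]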
@@ -1177,6 +1172,7 @@ class Qwen3VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
 
     packed_modules_mapping = {
         "qkv_proj": [
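The multimodal_cpu_fields attribute is the mechanism this commit adds: fields named in it are exempted from the usual host-to-device move when multimodal keyword inputs are transferred, so small bookkeeping tensors such as the grid shapes stay on the CPU and can be read cheaply. A minimal sketch of how another model class might opt in, mirroring the declaration above (MyVLModel is hypothetical; SupportsMultiModal is vLLM's multimodal interface):

    from torch import nn
    from vllm.model_executor.models.interfaces import SupportsMultiModal

    class MyVLModel(nn.Module, SupportsMultiModal):
        merge_by_field_config = True
        # These inputs skip the device transfer and stay host-resident.
        multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}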
@@ -1356,7 +1352,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
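The eager grid_thw.tolist() is dropped here: with image_grid_thw pinned to the CPU by multimodal_cpu_fields, consumers convert to a Python list only where one is actually required. For intuition on why placement matters, a hypothetical contrast (device placement assumed):

    import torch

    grid_gpu = torch.ones(256, 3, dtype=torch.int32, device="cuda")
    grid_cpu = torch.ones(256, 3, dtype=torch.int32)
    grid_gpu.tolist()  # forces a device sync plus a device-to-host copy
    grid_cpu.tolist()  # pure host-side conversion, nothing to wait for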
@@ -1364,18 +1359,14 @@ class Qwen3VLForConditionalGeneration(
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)
 
     def _process_video_input(
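Since grid_thw is now a CPU tensor, grid_thw.prod(-1) runs host-side and .tolist() costs nothing extra, which is why the list-based workaround (and the comment justifying it) could be deleted. A toy illustration with invented shapes and hidden size:

    import torch

    grid_thw = torch.tensor([[1, 8, 8], [1, 4, 4]])  # CPU tensor, [t, h, w] per image
    merge_size = 2
    sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()  # [16, 4]
    image_embeds = torch.randn(sum(sizes), 1152)  # concatenated patch embeddings
    per_image = image_embeds.split(sizes)  # one tensor per image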
@@ -1383,7 +1374,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1392,19 +1382,16 @@ class Qwen3VLForConditionalGeneration(
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype
             )
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: