[Model] Move multimodal_cpu_fields definition to field config (#30181)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-12-06 21:40:02 +08:00
committed by GitHub
parent 21bb323542
commit 671427efbf
15 changed files with 141 additions and 95 deletions

View File

@@ -787,10 +787,10 @@ class Glm4vVisionTransformer(nn.Module):
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
# patchify
x = x.to(device=self.device, dtype=self.dtype)
@@ -805,7 +805,8 @@ class Glm4vVisionTransformer(nn.Module):
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
@@ -1548,7 +1549,6 @@ class Glm4vForConditionalGeneration(
) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1559,12 +1559,10 @@ class Glm4vForConditionalGeneration(
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
merge_size = self.visual.spatial_merge_size
sizes = (
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
// (merge_size * merge_size)
).tolist()
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return image_embeds.split(sizes)
def _process_video_input(
@@ -1572,7 +1570,6 @@ class Glm4vForConditionalGeneration(
) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1588,15 +1585,11 @@ class Glm4vForConditionalGeneration(
rope_type="rope_3d",
)
else:
video_embeds = self.visual(
pixel_values_videos, grid_thw=grid_thw.tolist()
)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = (
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
// (merge_size * merge_size)
).tolist()
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return video_embeds.split(sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: