[Model] Move multimodal_cpu_fields definition to field config (#30181)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -787,10 +787,10 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: list[list[int]],
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
# Convert grid_thw to tensor (always expecting list format now)
|
||||
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
||||
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
@@ -805,7 +805,8 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||
|
||||
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
@@ -1548,7 +1549,6 @@ class Glm4vForConditionalGeneration(
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
@@ -1559,12 +1559,10 @@ class Glm4vForConditionalGeneration(
|
||||
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
|
||||
)
|
||||
else:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = (
|
||||
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
|
||||
// (merge_size * merge_size)
|
||||
).tolist()
|
||||
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
|
||||
return image_embeds.split(sizes)
|
||||
|
||||
def _process_video_input(
|
||||
@@ -1572,7 +1570,6 @@ class Glm4vForConditionalGeneration(
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
@@ -1588,15 +1585,11 @@ class Glm4vForConditionalGeneration(
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(
|
||||
pixel_values_videos, grid_thw=grid_thw.tolist()
|
||||
)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = (
|
||||
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
|
||||
// (merge_size * merge_size)
|
||||
).tolist()
|
||||
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
|
||||
return video_embeds.split(sizes)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
|
||||
@@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||
)
|
||||
|
||||
|
||||
@@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration(
|
||||
SupportsQuant,
|
||||
SupportsXDRoPE,
|
||||
):
|
||||
multimodal_cpu_fields = {"image_grid_thw"}
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
|
||||
@@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol):
|
||||
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
|
||||
"""
|
||||
|
||||
multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
|
||||
multimodal_cpu_fields: ClassVar[Set[str] | None] = None
|
||||
"""
|
||||
A set indicating CPU-only multimodal fields.
|
||||
[DEPRECATED] A set indicating CPU-only multimodal fields.
|
||||
"""
|
||||
|
||||
_processor_factory: ClassVar[_ProcessorFactories]
|
||||
@@ -279,6 +279,15 @@ def supports_multimodal(
|
||||
"please remove the override from your model."
|
||||
)
|
||||
|
||||
multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
|
||||
if multimodal_cpu_fields is not None:
|
||||
raise ValueError(
|
||||
"`multimodal_cpu_fields` is no longer effective, "
|
||||
"please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
|
||||
"(refer to https://github.com/vllm-project/vllm/pull/30181), "
|
||||
"and then remove the override from your model."
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@@ -201,8 +201,6 @@ class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
|
||||
dummy_inputs=OpenCUADummyInputsBuilder,
|
||||
)
|
||||
class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
||||
multimodal_cpu_fields = {"image_grid_thw"}
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
|
||||
@@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
SupportsMultiModalPruning,
|
||||
SupportsMRoPE,
|
||||
):
|
||||
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
|
||||
@@ -811,14 +811,14 @@ def _create_qwen2vl_field_factory(
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_embed_grid_sizes
|
||||
),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes
|
||||
),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_embed_grid_sizes
|
||||
),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
)
|
||||
|
||||
return _qwen2vl_field_config
|
||||
@@ -1131,8 +1131,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
|
||||
class Qwen2VLForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
@@ -1393,9 +1391,11 @@ class Qwen2VLForConditionalGeneration(
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"]
|
||||
if self.use_data_parallel:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
return run_dp_sharded_mrope_vision_model(
|
||||
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
|
||||
self.visual,
|
||||
pixel_values_videos,
|
||||
grid_thw.tolist(),
|
||||
rope_type="rope_3d",
|
||||
)
|
||||
else:
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
@@ -984,14 +984,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes
|
||||
),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes
|
||||
),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes
|
||||
),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration(
|
||||
SupportsMRoPE,
|
||||
SupportsEagle3,
|
||||
):
|
||||
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
Reference in New Issue
Block a user