[Model] Use merge_by_field_config for MM models (M-N) (#26710)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-14 01:27:01 +08:00
committed by GitHub
parent e3b90c1ba2
commit afc47e4de7
11 changed files with 127 additions and 331 deletions

View File

@@ -75,7 +75,6 @@ from .interfaces import (
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
@@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- nc: Number of crops (dynamic)
- bnc: Batch size * number of images * number of crops (dynamic)
- np: Number of patches
- tp: Token sequence positions
- pd: Patch dimension
"""
images: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
]
# Number of crops may vary per batch and image, so pass it as a list.
images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")]
image_masks: Annotated[
torch.Tensor | list[torch.Tensor] | None,
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
]
image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")]
image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")]
"""An index tensor that maps image features to their corresponding patch tokens."""
image_input_idx: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# An index tensor that maps image features to their corresponding patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]
@@ -1363,6 +1353,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
class MolmoForCausalLM(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
@@ -1451,18 +1443,12 @@ class MolmoForCausalLM(
if images is None:
return None
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of num_crops. Got type: {type(num_crops)}"
)
num_crops = flatten_bn(num_crops, concat=True)
img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
raise ValueError(
f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
)
self.img_patch_id = img_patch_id.flatten().unique().item()
if isinstance(img_patch_id, torch.Tensor):
img_patch_id = img_patch_id.item()
assert isinstance(img_patch_id, int)
self.img_patch_id = img_patch_id
return MolmoImageInputs(
images=images,
@@ -1481,17 +1467,9 @@ class MolmoForCausalLM(
num_crops = image_input["num_crops"]
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (
None if image_masks is None else flatten_bn(image_masks, concat=True)
)
image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(
None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
),
image_features = self.vision_backbone(
images=images.unsqueeze(0),
image_masks=None if image_masks is None else image_masks.unsqueeze(0),
).squeeze(0)
# Only the features corresponding to patch tokens are relevant
@@ -1499,8 +1477,8 @@ class MolmoForCausalLM(
results = []
num_crops_list = num_crops.tolist()
for feats, img_idx in zip(
image_features_flat.split(num_crops_list),
image_input_idx_flat.split(num_crops_list),
image_features.split(num_crops_list),
image_input_idx.split(num_crops_list),
):
is_valid = img_idx >= 0
valid_img_idx = img_idx[is_valid]