[Model] Use merge_by_field_config for MM models (M-N) (#26710)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -75,7 +75,6 @@ from .interfaces import (
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     make_layers,
@@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
-        - nc: Number of crops (dynamic)
+        - bnc: Batch size * number of images * number of crops (dynamic)
         - np: Number of patches
         - tp: Token sequence positions
         - pd: Patch dimension
     """
 
-    images: Annotated[
-        torch.Tensor | list[torch.Tensor],
-        TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
-    ]
-    # Number of crops may vary per batch and image, so pass it as a list.
+    images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")]
 
-    image_masks: Annotated[
-        torch.Tensor | list[torch.Tensor] | None,
-        TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
-    ]
+    image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")]
 
+    image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")]
+    """An index tensor that maps image features to their corresponding patch tokens."""
+
-    image_input_idx: Annotated[
-        torch.Tensor | list[torch.Tensor],
-        TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
-    ]
-    # An index tensor that maps image features to their corresponding patch tokens.
     num_crops: Annotated[torch.Tensor, TensorShape("bn")]
 
 
@@ -1363,6 +1353,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
 class MolmoForCausalLM(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
 ):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             # vision backbone mapping
@@ -1451,18 +1443,12 @@ class MolmoForCausalLM(
         if images is None:
             return None
 
-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of num_crops. Got type: {type(num_crops)}"
-            )
-        num_crops = flatten_bn(num_crops, concat=True)
-
         img_patch_id = kwargs.pop("img_patch_id", None)
-        if not isinstance(img_patch_id, torch.Tensor):
-            raise ValueError(
-                f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
-            )
-        self.img_patch_id = img_patch_id.flatten().unique().item()
+        if isinstance(img_patch_id, torch.Tensor):
+            img_patch_id = img_patch_id.item()
+
+        assert isinstance(img_patch_id, int)
+        self.img_patch_id = img_patch_id
 
         return MolmoImageInputs(
             images=images,
@@ -1481,17 +1467,9 @@ class MolmoForCausalLM(
         num_crops = image_input["num_crops"]
 
         # Call the vision backbone on the whole batch at once
-        images_flat = flatten_bn(images, concat=True)
-        image_masks_flat = (
-            None if image_masks is None else flatten_bn(image_masks, concat=True)
-        )
-        image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
-
-        image_features_flat = self.vision_backbone(
-            images=images_flat.unsqueeze(0),
-            image_masks=(
-                None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
-            ),
+        image_features = self.vision_backbone(
+            images=images.unsqueeze(0),
+            image_masks=None if image_masks is None else image_masks.unsqueeze(0),
         ).squeeze(0)
 
         # Only the features corresponding to patch tokens are relevant
@@ -1499,8 +1477,8 @@ class MolmoForCausalLM(
         results = []
         num_crops_list = num_crops.tolist()
         for feats, img_idx in zip(
-            image_features_flat.split(num_crops_list),
-            image_input_idx_flat.split(num_crops_list),
+            image_features.split(num_crops_list),
+            image_input_idx.split(num_crops_list),
         ):
             is_valid = img_idx >= 0
             valid_img_idx = img_idx[is_valid]
Reference in New Issue
Block a user