Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
This commit is contained in:
@@ -46,8 +46,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -57,6 +56,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@@ -84,6 +84,14 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size * num_images, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
@@ -1138,6 +1146,30 @@ class MolmoProcessorWrapper:
|
||||
if image_input_idx is not None:
|
||||
feat_is_patch = image_input_idx >= 0
|
||||
|
||||
input_is_embed = torch.isin(
|
||||
input_ids,
|
||||
torch.tensor([
|
||||
self.image_patch_id,
|
||||
self.im_col_id,
|
||||
self.im_start_id,
|
||||
self.im_end_id,
|
||||
]),
|
||||
)
|
||||
embed_ids = input_ids[input_is_embed]
|
||||
embed_is_patch = embed_ids == self.image_patch_id
|
||||
assert embed_is_patch.sum() == feat_is_patch.sum()
|
||||
|
||||
# image_tokens = extra_joint + joint
|
||||
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
|
||||
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
|
||||
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
|
||||
assert len(embed_start) == len(embed_end) == len(images)
|
||||
|
||||
embed_is_patch = [
|
||||
embed_is_patch[start:end + 1]
|
||||
for start, end in zip(embed_start, embed_end)
|
||||
]
|
||||
|
||||
tilings = [
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
@@ -1149,6 +1181,7 @@ class MolmoProcessorWrapper:
|
||||
assert num_crops.sum() == len(feat_is_patch)
|
||||
|
||||
outputs["feat_is_patch"] = feat_is_patch
|
||||
outputs["embed_is_patch"] = embed_is_patch
|
||||
outputs["num_crops"] = num_crops
|
||||
outputs["img_patch_id"] = self.image_patch_id
|
||||
|
||||
@@ -1187,13 +1220,17 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
pooling_size = processor.pooling_size
|
||||
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
base_image_input_size = processor.base_image_input_size
|
||||
base_image_input_d = processor.image_patch_size
|
||||
|
||||
extra = image_token_length_w * image_token_length_h
|
||||
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
|
||||
return extra + joint
|
||||
per_row = ncols // pooling_size + 1
|
||||
joint = per_row * (nrows // pooling_size) + 2
|
||||
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
|
||||
resize = (image_token_length + 1) * image_token_length + 2
|
||||
|
||||
return resize + joint
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
@@ -1291,6 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
"image", num_crops),
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@@ -1330,10 +1368,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
joint = ([img_start_id] + joint_row *
|
||||
((nrows + 1) // pooling_size) + [img_end_id])
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
extra_joint + joint,
|
||||
embed_token_id=img_patch_id,
|
||||
)
|
||||
image_tokens = extra_joint + joint
|
||||
return image_tokens
|
||||
|
||||
return [
|
||||
PromptInsertion(
|
||||
@@ -1439,6 +1475,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
raise ValueError("Incorrect type of feat_is_patch. "
|
||||
f"Got type: {type(feat_is_patch)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_crops = kwargs.pop("num_crops", None)
|
||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
@@ -1450,12 +1491,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
f"Got type: {type(img_patch_id)}")
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
return MolmoImageInputs(
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
feat_is_patch=feat_is_patch,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_crops=num_crops,
|
||||
)
|
||||
|
||||
@@ -1494,7 +1537,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -1508,7 +1556,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.img_patch_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
Reference in New Issue
Block a user