Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, Set, Tuple, TypedDict
|
||||
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# Cannot find the following 2 numbers from hf config.
|
||||
_IMAGE_TOKEN_ID = 71011
|
||||
@@ -65,6 +66,14 @@ class FuyuImagePatchInputs(TypedDict):
|
||||
flattened just like `flat_data`.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
@@ -85,7 +94,15 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
max_ncols, max_nrows = self.get_image_feature_grid_size(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
)
|
||||
max_image_tokens = (max_ncols + 1) * max_nrows
|
||||
|
||||
return {"image": max_image_tokens}
|
||||
|
||||
def get_image_feature_grid_size(
|
||||
self,
|
||||
@@ -111,32 +128,11 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
nrows = math.ceil(image_height / patch_height)
|
||||
return ncols, nrows
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
ncols, nrows = self.get_image_feature_grid_size(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
return ncols * nrows
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_image_processor()
|
||||
return ImageSize(width=image_processor.size["width"],
|
||||
height=image_processor.size["height"])
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
)
|
||||
|
||||
|
||||
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
|
||||
|
||||
@@ -196,6 +192,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
|
||||
# get patch grid size for each image
|
||||
embed_is_patch = []
|
||||
for image in images:
|
||||
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||
image_width=image.width,
|
||||
image_height=image.height,
|
||||
)
|
||||
|
||||
mask = torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
embed_is_patch.append(mask)
|
||||
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
@@ -215,7 +224,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(image_patches=MultiModalFieldConfig.batched("image"))
|
||||
return dict(image_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -242,9 +252,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -319,13 +329,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of image patches. "
|
||||
f"Got type: {type(image_patches)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
image_patches_flat = flatten_bn(image_patches)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return FuyuImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=self._validate_pixel_values(
|
||||
flatten_bn(image_patches_flat, concat=True)),
|
||||
patches_per_image=[x.size(0) for x in image_patches_flat],
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -347,7 +364,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -357,11 +379,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
_IMAGE_TOKEN_ID,
|
||||
)
|
||||
input_ids, inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user