[Misc] Clean up scatter_patch_features (#15559)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class InternVLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: NestedTensors
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
|
||||
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
|
||||
@@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: InternVLImageInputs,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if (kwargs.get("v0_path", False)
|
||||
or image_input["type"] != "pixel_values"):
|
||||
if image_input["type"] != "pixel_values":
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
kwargs.update({"v0_path": True})
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
|
||||
Reference in New Issue
Block a user