[V1] Initial support of multimodal models for V1 re-arch (#10699)

Signed-off-by: Roger Wang <ywang@roblox.com>
Author: Roger Wang
Date: 2024-12-08 04:50:51 -08:00
Committed by: GitHub
parent fd57d2b534
commit a11f326528
11 changed files with 283 additions and 68 deletions


@@ -26,7 +26,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
patches_per_image: List[int]
"""
List of the total number of patches for each image in the batch.
"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
data: NestedTensors
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`.

`hidden_size` must match the hidden size of the language model backbone.
"""
@@ -349,10 +355,32 @@ class InternVLInputPipeline:
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
img_context_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False)
assert len(img_context_token_id) == 1, \
(f"Invalid image token '{self.img_context_token}': A valid image "
f"token encodes to a single token ID, got {img_context_token_id}.")
img_context_token_id = img_context_token_id[0]
return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
# Track the exact position of each image placeholder in the token sequence
token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_feature_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_feature_size))
image_idx += 1
token_idx += curr_image_feature_size
else:
token_idx += 1
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
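The loop above records one PlaceholderRange per image: when it hits an image-context token it stores the current offset together with that image's feature size, then jumps past the whole placeholder run. A minimal sketch of the same scan over a toy token sequence follows; the token IDs and feature sizes are made up, and plain dicts stand in for PlaceholderRange.

IMG = 999                            # stand-in for img_context_token_id
image_feature_sizes = [4, 2]         # made-up per-image placeholder lengths
token_ids = [1, 2] + [IMG] * 4 + [3] + [IMG] * 2 + [4]

token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(token_ids):
    if token_ids[token_idx] == IMG:
        size = image_feature_sizes[image_idx]
        placeholder_ranges.append({"offset": token_idx, "length": size})
        image_idx += 1
        token_idx += size
    else:
        token_idx += 1

# placeholder_ranges == [{'offset': 2, 'length': 4}, {'offset': 7, 'length': 2}]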
def input_mapper(
self,
@@ -614,26 +642,46 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
patches_per_image = []
for request_pixel_values in pixel_values:
for image_pixel_values in request_pixel_values:
patches_per_image.append(image_pixel_values.shape[0])
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
patches_per_image=patches_per_image)
raise AssertionError("This line should be unreachable.")
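The comment about calling flatten_bn twice can be pictured with plain tensors. The sketch below only mimics the effect with list comprehensions and torch.cat rather than calling vLLM's helper, and the patch counts are made up:

import torch

# Two requests: the first carries two images, the second carries one.
# Each image is a (num_patches, 3, 448, 448) tensor.
pixel_values = [
    [torch.zeros(5, 3, 448, 448), torch.zeros(2, 3, 448, 448)],
    [torch.zeros(7, 3, 448, 448)],
]

patches_per_image = [img.shape[0] for req in pixel_values for img in req]
# -> [5, 2, 7]

# First flatten removes the per-request nesting: (B, N) lists -> B*N images.
flat_images = [img for req in pixel_values for img in req]
# Second flatten concatenates along the patch dimension -> (B*N*P, 3, 448, 448).
flat_patches = torch.cat(flat_images, dim=0)   # shape (14, 3, 448, 448)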
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
) -> Tuple[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])
patches_per_image = image_input["patches_per_image"]
if len(patches_per_image) == 1:
image_embeds = image_embeds.unsqueeze(0)
return image_embeds
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in patches_per_image
]
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds
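The split at the end of _process_image_input regroups the flattened embeddings per image by multiplying each image's patch count by the per-patch feature size. A toy walk-through with made-up sizes (hidden_size, feature_size, and patch counts are illustrative only):

import torch

hidden_size = 8
feature_size = 3                     # embedding rows produced per patch
patches_per_image = [2, 5]

# (total_patches, feature_size, hidden_size) flattened to rows of hidden_size.
image_embeds = torch.zeros(sum(patches_per_image), feature_size, hidden_size)
image_embeds = image_embeds.view(-1, hidden_size)                      # (21, 8)

image_feature_sizes = [n * feature_size for n in patches_per_image]    # [6, 15]
per_image = image_embeds.split(image_feature_sizes)
# per_image[0].shape == (6, 8); per_image[1].shape == (15, 8)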
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -696,13 +744,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"inputs_embeds": inputs_embeds,
}
# Only required if the model is mono-architecture
if self.visual_token_mask is not None:
# reset visual_token_mask and img_context_token_id back to None so that
# subsequent forward passes do not depend on encoder output
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
self.img_context_token_id = None
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
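The branch above passes visual_token_mask into the language model once and then clears it, so later forward passes (which may not have fresh encoder output) do not reuse stale state. A minimal sketch of that consume-once pattern, with illustrative names only:

class _OneShotVisualState:
    def __init__(self):
        self.visual_token_mask = None

    def forward_kwargs(self):
        kwargs = {}
        if self.visual_token_mask is not None:
            kwargs["visual_token_mask"] = self.visual_token_mask
            # Reset so later (e.g. decode-only) steps do not see stale state.
            self.visual_token_mask = None
        return kwargs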