[V1] Scatter and gather placeholders in the model runner (#16076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
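This commit removes the per-model `embed_is_patch` bookkeeping from the Llama 4 multimodal path: the processor now marks which part of each expanded image prompt actually receives embeddings via `PromptUpdateDetails.select_text`, and the V1 model runner performs the placeholder scatter/gather centrally. A minimal sketch of the mask idea the runner relies on, with made-up token ids (nothing below is vLLM API):

```python
import torch

# An expanded image prompt mixes structural tokens (<|image_start|>,
# tile separators, <|image_end|>) with <|patch|> slots. Only the patch
# slots receive vision embeddings.
token_ids = torch.tensor([200, 7, 7, 201, 7, 7, 202])  # 7 = <|patch|> (made up)
is_embed = token_ids == 7  # True exactly at embedding positions
assert int(is_embed.sum()) == 4
```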
@@ -33,7 +33,6 @@ from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import InputProcessingContext
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -50,7 +49,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+                                        PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -58,9 +57,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llama4 import Llama4ForCausalLM
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import scatter_patch_features, select_patch_features
-
-logger = init_logger(__name__)
 
 
 class Llama4ImagePatchInputs(TypedDict):
@@ -77,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict):
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
-    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond to patch tokens.
-    """
 
     aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
     """
     A list of aspect ratios corresponding to the number of tiles
@@ -510,11 +502,10 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
         vision_config = self.get_hf_config().vision_config
-        # image_start + local tiles * (patches + 1 x separator) +
-        # 1 global tile * (image x 1 + patches) + image_end
-        token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
-        mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
-        return {"image": mm_max_tokens}
+        patch_per_chunk = self.get_patch_per_chunk(vision_config)
+        num_patches = self.get_max_num_tiles() + 1
+
+        return {"image": patch_per_chunk * num_patches}
 
     def get_image_size_with_most_features(self) -> ImageSize:
         vision_config = self.get_hf_config().vision_config
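The new bound counts only positions that receive vision embeddings; start/end/separator tokens are no longer part of the profiling count. A rough numeric check using assumed Llama 4 vision settings (image_size=336, patch_size=14, pixel_shuffle_ratio=0.5, max_num_tiles=16; these values are illustrative, not taken from this diff):

```python
# Assumed, illustrative config values (not from this diff).
image_size, patch_size, pixel_shuffle_ratio = 336, 14, 0.5
max_num_tiles = 16

# (336 // 14) ** 2 = 576 patches per chunk, downsampled by 0.5 ** 2.
patch_per_chunk = int((image_size // patch_size) ** 2 * pixel_shuffle_ratio**2)

old_bound = (max_num_tiles + 1) * (patch_per_chunk + 1) + 2  # 17 * 145 + 2 = 2467
new_bound = (max_num_tiles + 1) * patch_per_chunk            # 17 * 144 = 2448
```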
@@ -523,6 +514,14 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
         return ImageSize(height=self.get_max_num_tiles() * image_size,
                          width=image_size)
 
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+        )
+
 
 class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
                                  ):
@@ -578,33 +577,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
             for (r_h, r_w) in aspect_ratios
         ]
 
-        # embed_is_patch should have one feature per image-related token:
-        # <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
-        # -> False
-        # <|patch|> -> True
-        # embed_is_patch has no entries corresponding to non-image-related
-        # tokens.
-        patch_id = tokenizer.get_vocab()[processor.img_patch_token]
-        num_patches_per_chunk = self.info.get_patch_per_chunk(
-            vision_config)
-        expanded_image_tokens_list = [
-            processor._prompt_split_image(aspect_ratio,
-                                          num_patches_per_chunk)
-            for aspect_ratio in aspect_ratios
-        ]
-        expanded_image_token_ids = [
-            tokenizer.encode(image_tokens, add_special_tokens=False)
-            for image_tokens in expanded_image_tokens_list
-        ]
-        embed_is_patch = [
-            torch.tensor(tokens) == patch_id
-            for tokens in expanded_image_token_ids
-        ]
-
         processed_outputs["aspect_ratios"] = aspect_ratios
         processed_outputs["patches_per_image"] = torch.tensor(
             patches_per_image)
-        processed_outputs["embed_is_patch"] = embed_is_patch
 
         return processed_outputs
@@ -619,7 +594,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
                                          "image", patches_per_image),
             patches_per_image=MultiModalFieldConfig.batched("image"),
             aspect_ratios=MultiModalFieldConfig.batched("image"),
-            embed_is_patch=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -639,12 +613,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
         num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_token = hf_processor.image_token
+        img_patch_token = hf_processor.img_patch_token
 
         def get_replacement(item_idx: int):
             aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
-            return hf_processor._prompt_split_image(
+
+            repl = hf_processor._prompt_split_image(
                 aspect_ratio=aspect_ratio,
-                num_patches_per_chunk=num_patches_per_chunk)
+                num_patches_per_chunk=num_patches_per_chunk,
+            )
+
+            return PromptUpdateDetails.select_text(repl, img_patch_token)
 
         return [
             PromptReplacement(
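`PromptUpdateDetails.select_text(repl, img_patch_token)` replaces the hand-rolled mask construction deleted above: the framework derives the embedding positions by matching the given token text inside the expanded replacement. Conceptually (a hypothetical helper, not vLLM's implementation):

```python
import torch

def embed_mask(expanded_token_ids: list[int], patch_id: int) -> torch.Tensor:
    # What select_text implies: an embedding slot wherever the
    # patch token appears in the expanded replacement.
    return torch.tensor(expanded_token_ids) == patch_id

# e.g. "<|image_start|><|patch|><|patch|><|image_end|>" with made-up ids:
mask = embed_mask([200, 7, 7, 201], patch_id=7)
assert mask.tolist() == [False, True, True, False]
```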
@@ -737,11 +716,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         flat_pixel_values = flatten_bn(pixel_values, concat=True)
         patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
 
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if not isinstance(embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
         aspect_ratios = kwargs.pop("aspect_ratios", None)
         if not isinstance(aspect_ratios, (torch.Tensor, list)):
             raise ValueError("Incorrect type of aspect_ratios. "
@@ -751,7 +725,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             type="pixel_values",
             flat_data=flat_pixel_values,
             patches_per_image=patches_per_image,
-            embed_is_patch=embed_is_patch,
             aspect_ratios=aspect_ratios,
         )
 
@@ -759,10 +732,15 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
         vision_embeddings_flat = self.vision_model(flat_data)
         vision_embeddings_flat = self.multi_modal_projector(
             vision_embeddings_flat)
-        return vision_embeddings_flat.split(patches_per_image, dim=0)
+
+        return [
+            img.flatten(0, 1)
+            for img in vision_embeddings_flat.split(patches_per_image, dim=0)
+        ]
 
     def get_multimodal_embeddings(self,
                                   **kwargs) -> Optional[MultiModalEmbeddings]:
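In shape terms, `_process_image_input` used to return one [num_chunks, num_patches, hidden_dim] tensor per image and let the caller flatten; it now flattens the chunk dimension itself. A toy check with invented sizes:

```python
import torch

hidden_dim, num_patches = 8, 4
patches_per_image = [2, 3]  # chunks per image (invented)
flat = torch.randn(sum(patches_per_image), num_patches, hidden_dim)

per_image = [
    img.flatten(0, 1)  # [num_chunks_i * num_patches, hidden_dim]
    for img in flat.split(patches_per_image, dim=0)
]
assert [tuple(t.shape) for t in per_image] == [(8, 8), (12, 8)]
```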
@@ -770,20 +748,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if image_input is None:
             return None
 
-        # num_images x [num_chunks, num_patches, hidden_dim]
-        image_features = self._process_image_input(image_input)
-        # num_images x [num_chunks x num_patches, hidden_dim]
-        image_features_flat = [img.flatten(0, 1) for img in image_features]
-        # num_images x [1, input_len] -> num_images x [input_len]
-        embed_is_patch_flat = [
-            is_patch.flatten(0, 1)
-            for is_patch in image_input["embed_is_patch"]
-        ]
-
-        return scatter_patch_features(
-            image_features_flat,
-            embed_is_patch_flat,
-        )
+        return self._process_image_input(image_input)
 
     def get_input_embeddings(
         self,
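With the scatter removed here, `get_multimodal_embeddings` returns raw per-image features and, per the commit title, the model runner scatters them into placeholder positions. A rough sketch of that runner-side operation (function and names are illustrative):

```python
import torch

def scatter_placeholders(features: torch.Tensor,
                         is_embed: torch.Tensor) -> torch.Tensor:
    # Expand [num_features, hidden] to [num_placeholder_tokens, hidden],
    # writing features only where is_embed is True.
    out = features.new_zeros(is_embed.shape[0], features.shape[-1])
    out[is_embed] = features
    return out

feats = torch.randn(3, 8)
mask = torch.tensor([False, True, True, False, True])
assert scatter_placeholders(feats, mask).shape == (5, 8)
```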
@@ -794,9 +759,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds,
-                select_patch_features(multimodal_embeddings),
-                self.config.image_token_index)
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.config.image_token_index,
+            )
 
         return inputs_embeds
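Since the runner now hands back embeddings aligned with the placeholder tokens, the model can pass `multimodal_embeddings` straight to `merge_multimodal_embeddings`, which overwrites the text-embedding rows holding the image token id. Roughly (a simplified sketch, not vLLM's exact implementation):

```python
import torch

def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 mm_embeds: torch.Tensor, image_token_index: int) -> torch.Tensor:
    # Overwrite rows whose token id is the image placeholder.
    mask = input_ids == image_token_index
    inputs_embeds[mask] = mm_embeds.to(dtype=inputs_embeds.dtype)
    return inputs_embeds

ids = torch.tensor([1, 32000, 32000, 2])  # 32000 = image token (made up)
merged = merge_sketch(ids, torch.zeros(4, 8), torch.ones(2, 8), 32000)
assert merged[1].sum() == 8 and merged[0].sum() == 0
```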