[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-18 17:52:00 +08:00
committed by GitHub
parent 5c79b0d648
commit 27e8d1ea3e
77 changed files with 431 additions and 383 deletions

View File

@@ -370,10 +370,16 @@ def _assert_inputs_equal(
if ignore_mm_keys is None:
ignore_mm_keys = set()
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
assert a_rest == b_rest, msg
a_data = a["mm_kwargs"].get_data()
b_data = b["mm_kwargs"].get_data()
for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)
a_data.pop(key, None)
b_data.pop(key, None)
assert a == b, msg
assert a_data == b_data, msg

View File

@@ -45,7 +45,8 @@ def test_processor_override(
video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token)
video_tok_count = processed_inputs["prompt_token_ids"].count(
video_token_id)
grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0]
grid_t, _, _ = processed_inputs["mm_kwargs"].get_data(
)["video_grid_thw"][0]
assert grid_t == expected_grid_t
assert video_tok_count == expected_toks_per_frame * grid_t

View File

@@ -108,7 +108,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches

View File

@@ -68,7 +68,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches

View File

@@ -51,14 +51,14 @@ def test_processor_override(
prompt = encode_tokens(tokenizer, prompt)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
mm_kwargs = processed_inputs["mm_kwargs"]
mm_data = processed_inputs["mm_kwargs"].get_data()
# place holder replacements
prompt_token_ids = processed_inputs["prompt_token_ids"]
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
aspect_ratios = mm_kwargs["aspect_ratios"]
aspect_ratios = mm_data["aspect_ratios"]
num_x_separators = num_y_separators = 0
for tiles_y, tiles_x in aspect_ratios:
if tiles_x * tiles_y > 1:
@@ -80,6 +80,6 @@ def test_processor_override(
num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \
== mm_kwargs["patches_per_image"].sum()
== sum(mm_data["patches_per_image"]) * num_patches_per_chunk
assert len(mm_data["pixel_values"]) \
== sum(mm_data["patches_per_image"])

View File

@@ -49,18 +49,18 @@ def test_profiling(
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
] * max_num_seqs
mm_kwargs = processor.apply(
mm_data = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
)["mm_kwargs"].get_data()
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
num_tiles = [[t] for t in mm_data.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles

View File

@@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int):
hf_config = ctx.get_hf_config(Llama4Config)
mm_kwargs = processor.apply(
mm_data = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
)["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
chunks_per_image = prod(mm_kwargs["patches_per_image"])
chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
0][1] # x-y seperator tokens
num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][
1] # x-y seperator tokens
total_tokens = total_num_patches.item() + num_tiles.item(
) + 3 # image start, image, image end

View File

@@ -70,7 +70,8 @@ def _run_check(
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values_flat"].shape
print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches

View File

@@ -48,7 +48,8 @@ def test_processor_override(
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
pixel_shape = processed_inputs["mm_kwargs"].get_data(
)["pixel_values"].shape
assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs