[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -370,10 +370,16 @@ def _assert_inputs_equal(
|
||||
if ignore_mm_keys is None:
|
||||
ignore_mm_keys = set()
|
||||
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
||||
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
|
||||
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
|
||||
|
||||
assert a_rest == b_rest, msg
|
||||
|
||||
a_data = a["mm_kwargs"].get_data()
|
||||
b_data = b["mm_kwargs"].get_data()
|
||||
|
||||
for key in ignore_mm_keys:
|
||||
a["mm_kwargs"].pop(key, None)
|
||||
b["mm_kwargs"].pop(key, None)
|
||||
a_data.pop(key, None)
|
||||
b_data.pop(key, None)
|
||||
|
||||
assert a == b, msg
|
||||
assert a_data == b_data, msg
|
||||
|
||||
@@ -45,7 +45,8 @@ def test_processor_override(
|
||||
video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token)
|
||||
video_tok_count = processed_inputs["prompt_token_ids"].count(
|
||||
video_token_id)
|
||||
grid_t, _, _ = processed_inputs["mm_kwargs"]["video_grid_thw"][0]
|
||||
grid_t, _, _ = processed_inputs["mm_kwargs"].get_data(
|
||||
)["video_grid_thw"][0]
|
||||
|
||||
assert grid_t == expected_grid_t
|
||||
assert video_tok_count == expected_toks_per_frame * grid_t
|
||||
|
||||
@@ -108,7 +108,8 @@ def _run_check(
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
|
||||
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
|
||||
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
|
||||
pixel_shape = processed_inputs["mm_kwargs"].get_data(
|
||||
)["pixel_values_flat"].shape
|
||||
|
||||
assert img_tok_count == 256 * total_expected_num_patches
|
||||
assert pixel_shape[0] == total_expected_num_patches
|
||||
|
||||
@@ -68,7 +68,8 @@ def _run_check(
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
|
||||
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
|
||||
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
|
||||
pixel_shape = processed_inputs["mm_kwargs"].get_data(
|
||||
)["pixel_values_flat"].shape
|
||||
|
||||
assert img_tok_count == 256 * total_expected_num_patches
|
||||
assert pixel_shape[0] == total_expected_num_patches
|
||||
|
||||
@@ -51,14 +51,14 @@ def test_processor_override(
|
||||
prompt = encode_tokens(tokenizer, prompt)
|
||||
|
||||
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
mm_kwargs = processed_inputs["mm_kwargs"]
|
||||
mm_data = processed_inputs["mm_kwargs"].get_data()
|
||||
|
||||
# place holder replacements
|
||||
prompt_token_ids = processed_inputs["prompt_token_ids"]
|
||||
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
|
||||
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
|
||||
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
|
||||
aspect_ratios = mm_kwargs["aspect_ratios"]
|
||||
aspect_ratios = mm_data["aspect_ratios"]
|
||||
num_x_separators = num_y_separators = 0
|
||||
for tiles_y, tiles_x in aspect_ratios:
|
||||
if tiles_x * tiles_y > 1:
|
||||
@@ -80,6 +80,6 @@ def test_processor_override(
|
||||
num_patches_per_chunk = processor.info.get_patch_per_chunk(
|
||||
config.vision_config)
|
||||
assert prompt_token_ids.count(config.image_token_index) \
|
||||
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
|
||||
assert mm_kwargs["pixel_values"].shape[0] \
|
||||
== mm_kwargs["patches_per_image"].sum()
|
||||
== sum(mm_data["patches_per_image"]) * num_patches_per_chunk
|
||||
assert len(mm_data["pixel_values"]) \
|
||||
== sum(mm_data["patches_per_image"])
|
||||
|
||||
@@ -49,18 +49,18 @@ def test_profiling(
|
||||
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
|
||||
] * max_num_seqs
|
||||
|
||||
mm_kwargs = processor.apply(
|
||||
mm_data = processor.apply(
|
||||
prompt=dummy_mm_data.prompt,
|
||||
mm_data=dummy_mm_data.mm_data,
|
||||
hf_processor_mm_kwargs=dict(),
|
||||
)["mm_kwargs"]
|
||||
)["mm_kwargs"].get_data()
|
||||
|
||||
# Get the actual number of encoder tokens for each sample.
|
||||
# Because attn_metadata.encoder_seq_lens only counts the last
|
||||
# group of images for each sample, which is used to cheat the
|
||||
# block manager to allocate blocks for those images only.
|
||||
# See MllamaMultiModalProcessor for more details.
|
||||
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
|
||||
num_tiles = [[t] for t in mm_data.pop("num_tiles")]
|
||||
num_tokens_per_tile = calc_token_per_chunk(image_size)
|
||||
actual_encoder_seq_lens = [
|
||||
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
|
||||
|
||||
@@ -38,21 +38,21 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
|
||||
hf_config = ctx.get_hf_config(Llama4Config)
|
||||
|
||||
mm_kwargs = processor.apply(
|
||||
mm_data = processor.apply(
|
||||
prompt=dummy_mm_data.prompt,
|
||||
mm_data=dummy_mm_data.mm_data,
|
||||
hf_processor_mm_kwargs=dict(),
|
||||
)["mm_kwargs"]
|
||||
)["mm_kwargs"].get_data()
|
||||
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
downsample_ratio = int(
|
||||
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
|
||||
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
|
||||
chunks_per_image = prod(mm_kwargs["patches_per_image"])
|
||||
chunks_per_image = prod(mm_data["patches_per_image"])
|
||||
total_num_patches = chunks_per_image * tokens_per_patch
|
||||
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
|
||||
0][1] # x-y seperator tokens
|
||||
num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][
|
||||
1] # x-y seperator tokens
|
||||
total_tokens = total_num_patches.item() + num_tiles.item(
|
||||
) + 3 # image start, image, image end
|
||||
|
||||
|
||||
@@ -70,7 +70,8 @@ def _run_check(
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
|
||||
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
|
||||
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
|
||||
pixel_shape = processed_inputs["mm_kwargs"].get_data(
|
||||
)["pixel_values_flat"].shape
|
||||
print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
|
||||
assert img_tok_count == 256 * total_expected_num_patches
|
||||
assert pixel_shape[0] == total_expected_num_patches
|
||||
|
||||
@@ -48,7 +48,8 @@ def test_processor_override(
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
|
||||
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
|
||||
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
|
||||
pixel_shape = processed_inputs["mm_kwargs"].get_data(
|
||||
)["pixel_values"].shape
|
||||
|
||||
assert img_tok_count == expected_toks_per_img * num_imgs
|
||||
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
|
||||
|
||||
Reference in New Issue
Block a user