[Refactor] Move top-level dummy data generation to registry (#32310)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-14 18:17:46 +08:00
committed by GitHub
parent b8199f6049
commit 90db5b31e4
6 changed files with 57 additions and 132 deletions

View File

@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int):
limit_mm_per_prompt=mm_counts,
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
processor,
max_model_len,
mm_counts=mm_counts,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
ctx.model_config,
mm_counts=mm_counts,
)
hf_config = ctx.get_hf_config(Llama4Config)
mm_inputs = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)
mm_data = mm_inputs["mm_kwargs"].get_data()
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))
)
tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio
mm_data = mm_inputs["mm_kwargs"].get_data()
chunks_per_image = prod(mm_data["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = (
@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int):
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
)
assert total_tokens == sum(
placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
)