[Refactor] Move top-level dummy data generation to registry (#32310)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
limit_mm_per_prompt=mm_counts,
|
||||
)
|
||||
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
|
||||
processor,
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
|
||||
max_model_len,
|
||||
mm_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
|
||||
ctx.model_config,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
|
||||
hf_config = ctx.get_hf_config(Llama4Config)
|
||||
|
||||
mm_inputs = processor.apply(
|
||||
prompt=dummy_mm_data.prompt,
|
||||
mm_data=dummy_mm_data.mm_data,
|
||||
hf_processor_mm_kwargs=dict(),
|
||||
)
|
||||
mm_data = mm_inputs["mm_kwargs"].get_data()
|
||||
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
downsample_ratio = int(
|
||||
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))
|
||||
)
|
||||
tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio
|
||||
|
||||
mm_data = mm_inputs["mm_kwargs"].get_data()
|
||||
chunks_per_image = prod(mm_data["patches_per_image"])
|
||||
total_num_patches = chunks_per_image * tokens_per_patch
|
||||
num_tiles = (
|
||||
@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
|
||||
)
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user