[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -74,11 +74,11 @@ def mm_model_cls():
|
||||
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
|
||||
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
|
||||
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
|
||||
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
|
||||
"pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
|
||||
}
|
||||
|
||||
|
||||
### Test for default processor logic & mm_processor_kwargs wrapping
|
||||
### Tests for default processor logic & mm_processor_kwargs wrapping
|
||||
def test_default_processor_is_a_noop():
|
||||
"""Ensure that by default, there is no processor override."""
|
||||
dummy_registry = InputRegistry()
|
||||
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
|
||||
assert proc_inputs is proc_outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
|
||||
def test_processor_default_kwargs(use_processor_mock, num_crops):
|
||||
"""Ensure input processors can use processor kwargs."""
|
||||
dummy_registry = InputRegistry()
|
||||
def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
|
||||
"""Get the init / inference kwargs and expected num_crops for this test."""
|
||||
# If we have a value for num_crops, pass the override value and make
|
||||
# sure we get that value as a return-value from out mock processor,
|
||||
# otherwise fall back to the default value
|
||||
mm_processor_kwargs = None if num_crops is None else {
|
||||
"num_crops": num_crops
|
||||
init_kwargs = None if init_num_crops is None else {
|
||||
"num_crops": init_num_crops
|
||||
}
|
||||
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
|
||||
ctx = build_model_context(DUMMY_MODEL_ID,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
inference_kwargs = None if inference_num_crops is None else {
|
||||
"num_crops": inference_num_crops
|
||||
}
|
||||
if inference_num_crops is not None:
|
||||
expected_seq_count = inference_num_crops
|
||||
elif init_num_crops is not None:
|
||||
expected_seq_count = init_num_crops
|
||||
else:
|
||||
expected_seq_count = DEFAULT_NUM_CROPS
|
||||
return init_kwargs, inference_kwargs, expected_seq_count
|
||||
|
||||
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
|
||||
assert num_crops_val == expected_num_crops
|
||||
|
||||
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
|
||||
(None, None),
|
||||
(NUM_CROPS_OVERRIDE, None),
|
||||
(DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
|
||||
])
|
||||
def test_input_processor_kwargs(use_processor_mock, init_num_crops,
|
||||
inference_num_crops):
|
||||
"""Ensure input processors can use processor kwargs."""
|
||||
dummy_registry = InputRegistry()
|
||||
|
||||
init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
|
||||
init_num_crops, inference_num_crops)
|
||||
|
||||
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
num_crops_val = processor(
|
||||
LLMInputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=inference_kwargs))
|
||||
assert num_crops_val == expected_seq_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
|
||||
mm_processor_kwargs):
|
||||
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
|
||||
dummy_registry = InputRegistry()
|
||||
# Should filter out the init time kwargs
|
||||
ctx = build_model_context(DUMMY_MODEL_ID,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
|
||||
# Should filter out the inference time kwargs
|
||||
num_crops_val = processor(
|
||||
LLMInputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=mm_processor_kwargs))
|
||||
assert num_crops_val == DEFAULT_NUM_CROPS
|
||||
|
||||
|
||||
@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
|
||||
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
|
||||
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
|
||||
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
|
||||
(None, None),
|
||||
(NUM_CROPS_OVERRIDE, None),
|
||||
(DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
|
||||
])
|
||||
def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
|
||||
inference_num_crops):
|
||||
"""Ensure custom mappers can use processor kwargs."""
|
||||
mm_processor_kwargs = None if num_crops is None else {
|
||||
"num_crops": num_crops
|
||||
}
|
||||
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
|
||||
init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
|
||||
init_num_crops, inference_num_crops)
|
||||
|
||||
ctx = build_model_context(MULTIMODAL_MODEL_ID,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
mm_processor_kwargs=init_kwargs,
|
||||
limit_mm_per_prompt={"image": 1})
|
||||
|
||||
mm_registry = MultiModalRegistry()
|
||||
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
|
||||
# Patch the image registry for phi3v with our lambda that is compatible
|
||||
# with overrides, then ensure that calling the method correctly echos
|
||||
# our num_crops value back from the mm_processor_kwargs.
|
||||
image = image_assets[0].pil_image
|
||||
mm_inputs = {"image": image}
|
||||
|
||||
with patch.object(
|
||||
mm_registry._get_plugin("image"),
|
||||
"_default_input_mapper",
|
||||
{mm_model_cls(): custom_mapper},
|
||||
):
|
||||
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
|
||||
# Patch the image registry for phi3v with our lambda that is compatible
|
||||
# with overrides, then ensure that calling the method correctly echos
|
||||
# our num_crops value back from the mm_processor_kwargs.
|
||||
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
|
||||
mm_model_cls())
|
||||
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
|
||||
inference_kwargs)
|
||||
|
||||
assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
|
||||
|
||||
@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
|
||||
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
|
||||
mm_processor_kwargs):
|
||||
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
|
||||
# Should filter out the init time kwargs
|
||||
ctx = build_model_context(MULTIMODAL_MODEL_ID,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
|
||||
|
||||
mm_registry = MultiModalRegistry()
|
||||
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
|
||||
# Patch the image registry for phi3v with our lambda that is compatible
|
||||
# with overrides, then ensure that calling the method correctly echos
|
||||
# our num_crops value back from the mm_processor_kwargs.
|
||||
image = image_assets[0].pil_image
|
||||
mm_inputs = {"image": image}
|
||||
|
||||
with patch.object(
|
||||
mm_registry._get_plugin("image"),
|
||||
"_default_input_mapper",
|
||||
{mm_model_cls(): custom_mapper},
|
||||
):
|
||||
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
|
||||
# Patch the image registry for phi3v with our lambda that is compatible
|
||||
# with overrides, then ensure that calling the method correctly echos
|
||||
# our num_crops value back from the mm_processor_kwargs.
|
||||
mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
|
||||
mm_model_cls())
|
||||
# Should filter out the inference time kwargs
|
||||
mapped_inputs = mm_registry.map_input(
|
||||
ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)
|
||||
|
||||
assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
|
||||
|
||||
Reference in New Issue
Block a user