Add embedding input functionality for disabled modalities [remake] (#32493)
Signed-off-by: Reagan Lee <“reaganjlee@gmail.com”> Signed-off-by: Reagan Lee <reaganjlee@gmail.com> Signed-off-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com> Co-authored-by: Reagan Lee <“reaganjlee@gmail.com”> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -899,40 +899,6 @@ def test_find_mm_placeholders(
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
@pytest.mark.parametrize(
|
||||
("limit", "num_supported", "is_valid"),
|
||||
[
|
||||
(0, 0, True),
|
||||
(0, 1, True),
|
||||
(1, 0, False),
|
||||
(1, 1, True),
|
||||
(1, 2, True),
|
||||
(2, 1, False),
|
||||
(2, 2, True),
|
||||
],
|
||||
)
|
||||
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
limit_mm_per_prompt = {"image": limit}
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=model_id,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
|
||||
processor.info.get_supported_mm_limits = lambda: {"image": num_supported}
|
||||
|
||||
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
|
||||
|
||||
with exc_ctx:
|
||||
MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
|
||||
model_config,
|
||||
mm_counts=limit_mm_per_prompt,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
@pytest.mark.parametrize(
|
||||
("num_images", "limit", "is_valid"),
|
||||
@@ -975,6 +941,50 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
@pytest.mark.parametrize(
|
||||
("user_limit", "supported_limit"),
|
||||
[
|
||||
(0, 0),
|
||||
(0, 1),
|
||||
(1, 0), # user wants 1, model supports 0 → capped to 0
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(2, 1), # user wants 2, model supports 1 → capped to 1
|
||||
(2, 2),
|
||||
(5, 1), # large user limit, low model support → capped to 1
|
||||
(1, 5),
|
||||
(10, 0), # large user limit, no model support → capped to 0
|
||||
],
|
||||
)
|
||||
def test_budget_caps_prevent_dummy_input_validation_failure(
|
||||
model_id, user_limit, supported_limit
|
||||
):
|
||||
limit_mm_per_prompt = {"image": user_limit}
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=model_id,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
|
||||
processor.info.get_supported_mm_limits = lambda: {"image": supported_limit}
|
||||
|
||||
# This is what budget.py uses to derive mm_counts
|
||||
allowed = processor.info.allowed_mm_limits
|
||||
|
||||
assert allowed["image"] <= supported_limit, (
|
||||
f"allowed_mm_limits['image']={allowed['image']} exceeds "
|
||||
f"supported_limit={supported_limit}"
|
||||
)
|
||||
|
||||
assert allowed["image"] <= user_limit, (
|
||||
f"allowed_mm_limits['image']={allowed['image']} exceeds user_limit={user_limit}"
|
||||
)
|
||||
|
||||
assert allowed["image"] == min(user_limit, supported_limit)
|
||||
|
||||
|
||||
class DummyProcessor:
|
||||
def __init__(self, a: int = 0, b: int = 0) -> None:
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user