[V1] Add disable_chunked_mm_input arg to disable partial mm input prefill (#15837)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -24,6 +24,7 @@ def create_scheduler(
|
||||
max_num_batched_tokens: int = 8192,
|
||||
enable_prefix_caching: Optional[bool] = None,
|
||||
long_prefill_token_threshold: int = 0,
|
||||
disable_chunked_mm_input: bool = False,
|
||||
) -> Scheduler:
|
||||
'''Create scheduler under test.
|
||||
|
||||
@@ -43,6 +44,7 @@ def create_scheduler(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_model_len=max_num_batched_tokens,
|
||||
long_prefill_token_threshold=long_prefill_token_threshold,
|
||||
disable_chunked_mm_input=disable_chunked_mm_input,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
model=model,
|
||||
@@ -278,6 +280,49 @@ def test_schedule_partial_requests():
|
||||
assert requests[2].request_id not in output.num_scheduled_tokens
|
||||
|
||||
|
||||
def test_no_mm_input_chunking():
|
||||
# Disable multimodal input chunking.
|
||||
scheduler = create_scheduler(
|
||||
model="llava-hf/llava-1.5-7b-hf",
|
||||
max_num_batched_tokens=1024,
|
||||
disable_chunked_mm_input=True,
|
||||
)
|
||||
mm_positions = [[PlaceholderRange(offset=400, length=800)]]
|
||||
requests = create_requests(num_requests=1,
|
||||
num_tokens=1200,
|
||||
mm_positions=mm_positions)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 1
|
||||
assert len(output.scheduled_cached_reqs) == 0
|
||||
assert len(output.finished_req_ids) == 0
|
||||
# We want to only see the 400 text tokens at the start scheduled
|
||||
assert output.num_scheduled_tokens[requests[0].request_id] == 400
|
||||
|
||||
req_to_index = {
|
||||
request.request_id: i
|
||||
for i, request in enumerate(requests)
|
||||
}
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
scheduler.update_from_output(output, model_runner_output)
|
||||
|
||||
output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
assert len(output.scheduled_cached_reqs) == 1
|
||||
assert len(output.finished_req_ids) == 0
|
||||
assert output.num_scheduled_tokens[requests[0].request_id] == 800
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
|
||||
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
"""Test scheduling behavior with concurrent partial requests.
|
||||
|
||||
Reference in New Issue
Block a user