[Core][KVConnector] Propagate all tokens on resumed preemptions (#24926)
Signed-off-by: Qier Li <kevin44036@gmail.com> Co-authored-by: Qier Li <qier@fb.com>
This commit is contained in:
@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
assert len(scheduler.waiting) == 1
|
||||
|
||||
|
||||
def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
|
||||
"""Test that priority scheduling preempts lower priority requests
|
||||
when out of KV cache space."""
|
||||
# Create scheduler with very limited memory to force preemption
|
||||
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
max_num_batched_tokens=200,
|
||||
num_blocks=5, # Can hold 64 tokens (first block is null)
|
||||
block_size=16, # Standard block size
|
||||
use_kv_connector=True,
|
||||
)
|
||||
|
||||
# Create a request and schedule it
|
||||
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
starting_idx=0,
|
||||
)[0]
|
||||
scheduler.add_request(request_low)
|
||||
# 1st schedule
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == 1
|
||||
|
||||
# Simulate model execution
|
||||
# Simulate model execution - 1st decode
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[request_low.request_id],
|
||||
req_id_to_index={request_low.request_id: 0},
|
||||
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
starting_idx=1,
|
||||
)[0]
|
||||
scheduler.add_request(request_high)
|
||||
# 2nd schedule
|
||||
output = scheduler.schedule()
|
||||
# KV cache should be full at this point
|
||||
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
|
||||
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == 2
|
||||
|
||||
# Simulate model execution
|
||||
# Simulate model execution - 2nd decode
|
||||
requests = [request_low, request_high]
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
@@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
)
|
||||
scheduler.update_from_output(output, model_output)
|
||||
|
||||
# Schedule again - this should trigger preemption
|
||||
# 3rd schedule - this should trigger preemption
|
||||
# req_low needs 32 tokens = 2 blocks
|
||||
# req_high needs 33 tokens = 3 blocks
|
||||
# so doesn't fit in 4 blocks.
|
||||
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
assert output.scheduled_cached_reqs.num_reqs == 1
|
||||
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
|
||||
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert len(scheduler.running) == 1
|
||||
|
||||
# Simulate model execution - 3rd decode
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
|
||||
sampled_token_ids=[[], [100]],
|
||||
# spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
# Finish the requests to make room for the preempted requests to resume
|
||||
scheduler.update_from_output(output, model_output)
|
||||
scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
|
||||
|
||||
# 4th Schedule - this should trigger the resumption
|
||||
output = scheduler.schedule()
|
||||
scheduled_cached_reqs = output.scheduled_cached_reqs
|
||||
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
|
||||
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
assert scheduled_cached_reqs.num_reqs == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == 1
|
||||
|
||||
# Preempted request resumed in scheduled_cached_reqs
|
||||
assert len(resumed_from_preemption) == 1
|
||||
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
|
||||
assert resumed_from_preemption[0]
|
||||
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
|
||||
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
|
||||
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
|
||||
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
|
||||
assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
|
||||
|
||||
Reference in New Issue
Block a user