[Core][KVConnector] Propagate all tokens on resumed preemptions (#24926)

Signed-off-by: Qier Li <kevin44036@gmail.com>
Co-authored-by: Qier Li <qier@fb.com>
This commit is contained in:
Qier Li
2025-10-09 02:43:31 -04:00
committed by GitHub
parent 43ab8cfaa5
commit d17f0fbf30
4 changed files with 60 additions and 9 deletions

View File

@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
assert len(scheduler.waiting) == 1
def test_priority_scheduling_preemption_when_out_of_kv():
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
max_num_batched_tokens=200,
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
)
# Create a request and schedule it
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
starting_idx=0,
)[0]
scheduler.add_request(request_low)
# 1st schedule
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Simulate model execution
# Simulate model execution - 1st decode
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
starting_idx=1,
)[0]
scheduler.add_request(request_high)
# 2nd schedule
output = scheduler.schedule()
# KV cache should be full at this point
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2
# Simulate model execution
# Simulate model execution - 2nd decode
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
)
scheduler.update_from_output(output, model_output)
# Schedule again - this should trigger preemption
# 3rd schedule - this should trigger preemption
# req_low needs 32 tokens = 2 blocks
# req_high needs 33 tokens = 3 blocks
# so doesn't fit in 4 blocks.
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
# Simulate model execution - 3rd decode
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[], [100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Finish the high-priority request to make room for the preempted request to resume
scheduler.update_from_output(output, model_output)
scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
# 4th schedule - this should trigger the resumption
output = scheduler.schedule()
scheduled_cached_reqs = output.scheduled_cached_reqs
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
assert len(output.scheduled_new_reqs) == 0
assert scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Preempted request resumed in scheduled_cached_reqs
assert len(resumed_from_preemption) == 1
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
assert resumed_from_preemption[0]
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
@pytest.mark.parametrize(
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),