[Optimization] Cache sampled token ids in model runner (#20291)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-07-01 11:01:31 -07:00
committed by GitHub
parent 02cabff207
commit 7f280d69c9
5 changed files with 91 additions and 45 deletions

View File

@@ -172,7 +172,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state.block_ids[0]).all()
def test_update_states_new_request(model_runner):
def test_update_states_new_request(model_runner, dist_init):
req_id = "req_0"
# new req
@@ -186,7 +186,7 @@ def test_update_states_new_request(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_finished(model_runner):
def test_update_states_request_finished(model_runner, dist_init):
req_id = "req_0"
# new req
@@ -218,7 +218,7 @@ def test_update_states_request_finished(model_runner):
assert not _is_req_scheduled(model_runner, req_id)
def test_update_states_request_resumed(model_runner):
def test_update_states_request_resumed(model_runner, dist_init):
req_id = "req_0"
# new req
@@ -278,7 +278,7 @@ def test_update_states_request_resumed(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_get_nans_in_logits(model_runner):
def test_get_nans_in_logits(model_runner, dist_init):
req_ids = ("req_0", "req_1")
scheduler_output = _schedule_new_request(*req_ids)
@@ -326,7 +326,7 @@ def test_get_nans_in_logits(model_runner):
assert result == {'req_0': 2, 'req_1': 0}
def test_update_states_no_changes(model_runner):
def test_update_states_no_changes(model_runner, dist_init):
req_id = "req_0"
# new req
@@ -359,7 +359,7 @@ def test_update_states_no_changes(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_unscheduled(model_runner):
def test_update_states_request_unscheduled(model_runner, dist_init):
req_ids = ("req_0", "req_1")
# new reqs