[Optimization] Cache sampled token ids in model runner (#20291)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -172,7 +172,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
req_state.block_ids[0]).all()
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner):
|
||||
def test_update_states_new_request(model_runner, dist_init):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
@@ -186,7 +186,7 @@ def test_update_states_new_request(model_runner):
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_finished(model_runner):
|
||||
def test_update_states_request_finished(model_runner, dist_init):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
@@ -218,7 +218,7 @@ def test_update_states_request_finished(model_runner):
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_resumed(model_runner):
|
||||
def test_update_states_request_resumed(model_runner, dist_init):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
@@ -278,7 +278,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_get_nans_in_logits(model_runner):
|
||||
def test_get_nans_in_logits(model_runner, dist_init):
|
||||
req_ids = ("req_0", "req_1")
|
||||
|
||||
scheduler_output = _schedule_new_request(*req_ids)
|
||||
@@ -326,7 +326,7 @@ def test_get_nans_in_logits(model_runner):
|
||||
assert result == {'req_0': 2, 'req_1': 0}
|
||||
|
||||
|
||||
def test_update_states_no_changes(model_runner):
|
||||
def test_update_states_no_changes(model_runner, dist_init):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
@@ -359,7 +359,7 @@ def test_update_states_no_changes(model_runner):
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_unscheduled(model_runner):
|
||||
def test_update_states_request_unscheduled(model_runner, dist_init):
|
||||
req_ids = ("req_0", "req_1")
|
||||
|
||||
# new reqs
|
||||
|
||||
Reference in New Issue
Block a user