[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)
This commit is contained in:
@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 2
|
||||
assert out.num_batched_tokens == 2
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out != []
|
||||
assert out.blocks_to_swap_in == []
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Add 1 more task. Swap should be prioritized over prefill.
|
||||
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
|
||||
assert len(out.scheduled_seq_groups) == 3
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 3
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in != []
|
||||
assert out.blocks_to_swap_out == []
|
||||
|
||||
|
||||
def initialize_scheduler(*,
|
||||
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
|
||||
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
|
||||
# assert budget.num_curr_seqs == 1
|
||||
# Both should be preempted, not swapped.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
assert output.blocks_to_swap_out == []
|
||||
# Nothing is copied.
|
||||
assert output.blocks_to_copy == []
|
||||
|
||||
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
scheduler.block_manager.swap_out = MagicMock()
|
||||
expected_swap_mapping = {"5": "7"}
|
||||
expected_swap_mapping = [("5", "7")]
|
||||
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
|
||||
|
||||
remainig_running, output = scheduler._schedule_running(
|
||||
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
assert len(output.preempted) == 0
|
||||
assert len(output.swapped_out) == 0
|
||||
# Nothing is preempted.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
assert output.blocks_to_swap_out == []
|
||||
# Since append_slot returns the source -> dist mapping, it should
|
||||
# applied.
|
||||
assert output.blocks_to_copy == [(2, 3)]
|
||||
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
# swap in is the reverse of swap out
|
||||
blocks_to_swap_in_reverse = {}
|
||||
for swapin, swapout in output.blocks_to_swap_in.items():
|
||||
blocks_to_swap_in_reverse[swapout] = swapin
|
||||
blocks_to_swap_in_reverse = []
|
||||
for swapin, swapout in output.blocks_to_swap_in:
|
||||
blocks_to_swap_in_reverse.append((swapout, swapin))
|
||||
assert blocks_to_swap_out == blocks_to_swap_in_reverse
|
||||
|
||||
|
||||
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for i in range(4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = set()
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
prompt_length=60,
|
||||
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -808,7 +808,7 @@ def test_infeasible_swap():
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
blocks_to_swap_out = {}
|
||||
blocks_to_swap_out = []
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user