[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)

This commit is contained in:
youkaichao
2024-05-08 12:07:05 -07:00
committed by GitHub
parent ad932a221d
commit 20cfcdec99
21 changed files with 137 additions and 109 deletions

View File

@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 2
assert out.num_batched_tokens == 2
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
append_new_token(out, 1)
# Add 1 more task. Swap should be prioritized over prefill.
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
assert len(out.scheduled_seq_groups) == 3
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 3
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def initialize_scheduler(*,
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
# assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
assert output.blocks_to_swap_out == []
# Nothing is copied.
assert output.blocks_to_copy == []
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
scheduler.block_manager.swap_out = MagicMock()
expected_swap_mapping = {"5": "7"}
expected_swap_mapping = [("5", "7")]
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
remainig_running, output = scheduler._schedule_running(
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert len(output.preempted) == 0
assert len(output.swapped_out) == 0
# Nothing is preempted.
assert output.blocks_to_swap_out == {}
assert output.blocks_to_swap_out == []
# Since append_slot returns the source -> dist mapping, it should
# applied.
assert output.blocks_to_copy == [(2, 3)]
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
# swap in is the reverse of swap out
blocks_to_swap_in_reverse = {}
for swapin, swapout in output.blocks_to_swap_in.items():
blocks_to_swap_in_reverse[swapout] = swapin
blocks_to_swap_in_reverse = []
for swapin, swapout in output.blocks_to_swap_in:
blocks_to_swap_in_reverse.append((swapout, swapin))
assert blocks_to_swap_out == blocks_to_swap_in_reverse
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group)
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = set()
blocks_to_swap_out = {}
blocks_to_swap_out = []
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@@ -808,7 +808,7 @@ def test_infeasible_swap():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out = {}
blocks_to_swap_out = []
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)