[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size):
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
seq_lens = []
|
||||
seq_group_metadata_list = []
|
||||
block_tables = {0: [1]}
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = SequenceData(list(range(prompt_len)))
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size):
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
for prompt_len in prompt_lens:
|
||||
for seq_len in seq_lens:
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += prompt_len
|
||||
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
|
||||
_, _,
|
||||
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert return_prompt_lens == prompt_lens
|
||||
seq_len - 1)
|
||||
selected_token_start_idx += seq_len
|
||||
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
|
||||
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert return_seq_lens == seq_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
assert attn_metadata.is_prompt is True
|
||||
assert torch.allclose(attn_metadata.prompt_lens_tensor,
|
||||
torch.tensor(prompt_lens, device=device))
|
||||
assert attn_metadata.prompt_lens == prompt_lens
|
||||
assert attn_metadata.max_prompt_len == max(prompt_lens)
|
||||
assert torch.allclose(
|
||||
attn_metadata.seq_lens_tensor,
|
||||
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
assert attn_metadata.max_seq_len == max(seq_lens)
|
||||
|
||||
# Test subquery start locs.
|
||||
start_idx = 0
|
||||
start_loc = [start_idx]
|
||||
for prompt_len in prompt_lens:
|
||||
start_idx += prompt_len
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
start_loc.append(start_idx)
|
||||
assert torch.allclose(
|
||||
attn_metadata.subquery_start_loc,
|
||||
@@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size):
|
||||
# equivalent to subquery_start_loc.
|
||||
start_idx = 0
|
||||
seq_start_loc = [start_idx]
|
||||
for prompt_len in prompt_lens:
|
||||
start_idx += prompt_len
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
seq_start_loc.append(start_idx)
|
||||
|
||||
assert torch.allclose(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
assert attn_metadata.max_context_len is None
|
||||
assert torch.allclose(
|
||||
attn_metadata.context_lens,
|
||||
torch.zeros(attn_metadata.context_lens.shape[0],
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
|
||||
dtype=torch.int,
|
||||
device=device))
|
||||
|
||||
@@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size):
|
||||
# Cuda graph should not be used for prefill.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
assert len(input_tokens) == sum(seq_lens)
|
||||
assert len(input_positions) == sum(seq_lens)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
assert len(input_tokens) == sum(seq_lens)
|
||||
assert len(input_positions) == sum(seq_lens)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
@@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
seq_lens = []
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = list(range(prompt_len))
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = list(range(seq_len))
|
||||
seq_data = SequenceData(seq_data)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
@@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
# Verify input metadata is correct for decodes.
|
||||
device = model_runner.device
|
||||
assert attn_metadata.is_prompt is False
|
||||
assert attn_metadata.prompt_lens is None
|
||||
assert attn_metadata.max_prompt_len is None
|
||||
assert attn_metadata.seq_lens is None
|
||||
assert attn_metadata.subquery_start_loc is None
|
||||
assert attn_metadata.seq_start_loc is None
|
||||
assert attn_metadata.max_context_len == max(prompt_lens)
|
||||
assert attn_metadata.max_seq_len == max(seq_lens)
|
||||
assert torch.allclose(
|
||||
attn_metadata.context_lens[:len(prompt_lens)],
|
||||
torch.tensor(prompt_lens, dtype=torch.int, device=device))
|
||||
attn_metadata.seq_lens_tensor[:len(seq_lens)],
|
||||
torch.tensor(seq_lens, dtype=torch.int, device=device))
|
||||
|
||||
# block table's first index corresponds to each batch, meaning in
|
||||
# decoding it is each token.
|
||||
@@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
# Verify Sampling
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
for prompt_len in prompt_lens:
|
||||
for seq_len in seq_lens:
|
||||
expected_selected_token_indices.append(selected_token_start_idx)
|
||||
selected_token_start_idx += 1
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
@@ -241,14 +239,13 @@ def test_empty_seq_group():
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
|
||||
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
|
||||
_, _,
|
||||
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
|
||||
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
assert len(return_prompt_lens) == 0
|
||||
assert len(return_seq_lens) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
# Add prefill requests.
|
||||
prompt_lens = []
|
||||
seq_lens = []
|
||||
seq_group_metadata_list = []
|
||||
prefill_metadata_list = []
|
||||
decode_metadata_list = []
|
||||
@@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
decode_batch_size = batch_size - prefill_batch_size
|
||||
for i in range(prefill_batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = SequenceData(list(range(prompt_len)))
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(list(range(seq_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
# Add decode requests
|
||||
for i in range(prefill_batch_size, batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_toks = list(range(prompt_len))
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_toks = list(range(seq_len))
|
||||
seq_data = SequenceData(prompt_toks)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
@@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
else:
|
||||
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
|
||||
decode_batch_size)
|
||||
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
|
||||
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
|
||||
|
||||
# Verify attn metadata is consistent. We don't need to test individual
|
||||
# values here because they are tested above.
|
||||
|
||||
Reference in New Issue
Block a user