[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

Author: SangBin Cho
Date: 2024-05-04 02:20:12 +09:00
Committed by: GitHub
Parent: 2d7bce9cd5
Commit: 3521ba4f25
27 changed files with 554 additions and 525 deletions


@@ -58,7 +58,7 @@ def _do_sample(
     device: str,
 ):
     seq_group_metadata_list = []
-    prompt_lens = []
+    seq_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -68,12 +68,12 @@ def _do_sample(
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
-        prompt_lens,
-        subquery_lens=prompt_lens,
+        seq_lens,
+        query_lens=seq_lens,
         device=device,
         pin_memory=model_runner.pin_memory)
     return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
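
For reference, the renamed call shape used throughout these tests can be read directly off the hunk above. A minimal sketch, assuming vLLM's import path for SamplingMetadata; the import and the wrapper function are illustrative and not part of this diff:

from vllm.model_executor.sampling_metadata import SamplingMetadata  # assumed path


def prepare_metadata(seq_group_metadata_list, seq_lens, device, pin_memory):
    # These test batches are prompt-only, so query_lens equals seq_lens;
    # the arguments were previously called prompt_lens and subquery_lens.
    return SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,                # formerly the prompt_lens argument
        query_lens=seq_lens,     # formerly subquery_lens=prompt_lens
        device=device,
        pin_memory=pin_memory,
    )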
@@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list"
batch_size = 0
prompt_lens = []
seq_lens = []
sampling_params_per_row = []
for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params
@@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
                 # a prompt seq_group has only one sequence
                 seq_data = next(iter(sgm.seq_data.values()))
                 prompt_len = seq_data.get_prompt_len()
-                prompt_lens.append(prompt_len)
+                seq_lens.append(prompt_len)
 
                 if sgm.sampling_params.prompt_logprobs:
                     # with prompt_logprobs each token in the prompt has a row in
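
In the prompt rows handled above, the sequence length equals the prompt length, which is why get_prompt_len() now feeds the renamed seq_lens list. A minimal sketch of that collection pattern; the wrapper function and the is_prompt check are illustrative, only the loop body appears in the diff:

def collect_seq_lens(seq_group_metadata_list):
    # Mirrors the loop above: a prompt sequence group holds a single
    # sequence, and its prompt length is its sequence length here.
    seq_lens = []
    for sgm in seq_group_metadata_list:
        if sgm.is_prompt:
            seq_data = next(iter(sgm.seq_data.values()))
            seq_lens.append(seq_data.get_prompt_len())
    return seq_lens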
@@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
-            prompt_lens=prompt_lens if prompt_lens else None,
-            subquery_lens=prompt_lens if prompt_lens else None,
+            seq_lens=seq_lens if seq_lens else None,
+            query_lens=seq_lens if seq_lens else None,
             device=device,
             pin_memory=model_runner.pin_memory)
         # the logits tensor is modified in-place by the sampler
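
When a generated test case contains no prompt sequence groups, seq_lens stays empty and the call above passes None instead of an empty list. A minimal sketch of that guard, under the same assumed import as the earlier sketch; the wrapper is hypothetical:

def prepare_metadata_or_none(seq_group_metadata_list, seq_lens, device, pin_memory):
    # An empty seq_lens means the batch has no prompt rows; pass None,
    # as the hunk above does, rather than an empty list.
    lens = seq_lens if seq_lens else None
    return SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens=lens,
        query_lens=lens,
        device=device,
        pin_memory=pin_memory,
    )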
@@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
 
     seq_group_metadata_list = []
     expected_tokens: List[Optional[List[int]]] = []
-    prompt_lens = []
+    seq_lens = []
     for i in range(batch_size):
         expected: Optional[List[int]] = None
         sampling_type = random.randint(0, 3)
@@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str):
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     def test_sampling(model_runner: ModelRunner):
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
-            prompt_lens,
-            subquery_lens=prompt_lens,
+            seq_lens,
+            query_lens=seq_lens,
             device=device,
             pin_memory=model_runner.pin_memory)
         sampler_output = sampler(logits=fake_logits,
@@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
     # Shuffle the batch and resample
     target_index = list(range(batch_size))
     for list_to_shuffle in (target_index, seq_group_metadata_list,
-                            expected_tokens, prompt_lens):
+                            expected_tokens, seq_lens):
         random.Random(seed).shuffle(list_to_shuffle)
     target_index = torch.tensor(target_index)
     input_tensor.data = input_tensor.index_select(0, target_index)
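
The shuffle step keeps target_index, seq_group_metadata_list, expected_tokens, and seq_lens aligned by constructing random.Random with the same seed for every list. A small self-contained sketch of that idiom; the helper name is hypothetical:

import random


def shuffle_in_parallel(seed, *lists):
    # Re-creating the generator with the same seed gives each list the same
    # permutation, so rows stay aligned across all of them.
    for lst in lists:
        random.Random(seed).shuffle(lst)


rows = [0, 1, 2, 3]
labels = ["a", "b", "c", "d"]
shuffle_in_parallel(0, rows, labels)
# Each label still sits in the row it started with.
assert [labels[rows.index(i)] for i in range(4)] == ["a", "b", "c", "d"]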
@@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
     assert len(warpers) == 2  # top_p and top_k
 
     seq_group_metadata_list = []
-    prompt_lens = []
+    seq_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
                 ),
                 block_tables={0: [1]},
             ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
-        prompt_lens,
-        subquery_lens=prompt_lens,
+        seq_lens,
+        query_lens=seq_lens,
         device=device,
         pin_memory=model_runner.pin_memory)