[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -58,7 +58,7 @@ def _do_sample(
|
||||
device: str,
|
||||
):
|
||||
seq_group_metadata_list = []
|
||||
- prompt_lens = []
+ seq_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -68,12 +68,12 @@ def _do_sample(
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
- prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+ seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
- prompt_lens,
- subquery_lens=prompt_lens,
+ seq_lens,
+ query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
|
||||
@@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
"Invalid test case, need seq_group_metadata_list"
|
||||
|
||||
batch_size = 0
|
||||
- prompt_lens = []
+ seq_lens = []
|
||||
sampling_params_per_row = []
|
||||
for sgm in seq_group_metadata_list:
|
||||
sampling_params = sgm.sampling_params
|
||||
@@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
# a prompt seq_group has only one sequence
|
||||
seq_data = next(iter(sgm.seq_data.values()))
|
||||
prompt_len = seq_data.get_prompt_len()
|
||||
- prompt_lens.append(prompt_len)
+ seq_lens.append(prompt_len)
|
||||
|
||||
if sgm.sampling_params.prompt_logprobs:
|
||||
# with prompt_logprobs each token in the prompt has a row in
|
||||
@@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
- prompt_lens=prompt_lens if prompt_lens else None,
- subquery_lens=prompt_lens if prompt_lens else None,
+ seq_lens=seq_lens if seq_lens else None,
+ query_lens=seq_lens if seq_lens else None,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
@@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
|
||||
seq_group_metadata_list = []
|
||||
expected_tokens: List[Optional[List[int]]] = []
|
||||
- prompt_lens = []
+ seq_lens = []
|
||||
for i in range(batch_size):
|
||||
expected: Optional[List[int]] = None
|
||||
sampling_type = random.randint(0, 3)
|
||||
@@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
- prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+ seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
def test_sampling(model_runner: ModelRunner):
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
- prompt_lens,
- subquery_lens=prompt_lens,
+ seq_lens,
+ query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
sampler_output = sampler(logits=fake_logits,
|
||||
@@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
# Shuffle the batch and resample
|
||||
target_index = list(range(batch_size))
|
||||
for list_to_shuffle in (target_index, seq_group_metadata_list,
|
||||
- expected_tokens, prompt_lens):
+ expected_tokens, seq_lens):
|
||||
random.Random(seed).shuffle(list_to_shuffle)
|
||||
target_index = torch.tensor(target_index)
|
||||
input_tensor.data = input_tensor.index_select(0, target_index)
|
||||
@@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
assert len(warpers) == 2 # top_p and top_k
|
||||
|
||||
seq_group_metadata_list = []
|
||||
- prompt_lens = []
+ seq_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
- prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+ seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
- prompt_lens,
- subquery_lens=prompt_lens,
+ seq_lens,
+ query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user