[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
|
||||
return logits
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
seq_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str):
|
||||
logits_processors=[pick_ith]),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
logits_processor_output = logits_processor(
|
||||
|
||||
Reference in New Issue
Block a user