[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

This commit is contained in:
SangBin Cho
2024-05-04 02:20:12 +09:00
committed by GitHub
parent 2d7bce9cd5
commit 3521ba4f25
27 changed files with 554 additions and 525 deletions

View File

@@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return logits
seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str):
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor(