[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

This commit is contained in:
SangBin Cho
2024-05-04 02:20:12 +09:00
committed by GitHub
parent 2d7bce9cd5
commit 3521ba4f25
27 changed files with 554 additions and 525 deletions

View File

@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
prompts: List[List[int]],
num_gpu_blocks: int,
block_size: int,
final_seq_lens: List[int],
final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_seq_lens)
for i, final_len in enumerate(final_prompt_lens)
}
return [
@@ -251,13 +251,13 @@ def create_batch(batch_size,
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_seq_lens = [
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_seq_lens,
block_size, final_prompt_lens,
prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens