[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
|
||||
prompts: List[List[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_seq_lens: List[int],
|
||||
final_prompt_lens: List[int],
|
||||
continuations: Optional[List[List[int]]] = None,
|
||||
seq_ids: Optional[List[int]] = None,
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(final_len, block_size))
|
||||
]
|
||||
for i, final_len in enumerate(final_seq_lens)
|
||||
for i, final_len in enumerate(final_prompt_lens)
|
||||
}
|
||||
|
||||
return [
|
||||
@@ -251,13 +251,13 @@ def create_batch(batch_size,
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_seq_lens = [
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
execute_model_data = create_execute_model_data(
|
||||
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
|
||||
block_size, final_seq_lens,
|
||||
block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids), )
|
||||
return execute_model_data, prompts, prev_output_tokens
|
||||
|
||||
Reference in New Issue
Block a user