[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
@@ -45,7 +45,7 @@ class AsyncLLM:
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
         enforce_eager: bool = False,
-        max_context_len_to_capture: int = 8192,
+        max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         **kwargs,
     ) -> None:
@@ -66,7 +66,7 @@ class AsyncLLM:
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
             enforce_eager=enforce_eager,
-            max_context_len_to_capture=max_context_len_to_capture,
+            max_seq_len_to_capture=max_seq_len_to_capture,
             engine_use_ray=True,
             disable_custom_all_reduce=disable_custom_all_reduce,
             **kwargs,
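The two hunks above rename the CUDA-graph capture limit from max_context_len_to_capture to max_seq_len_to_capture; the remaining hunks rename final_seq_lens to final_prompt_lens throughout the spec-decode test helpers. A minimal sketch of a caller updated for the new engine argument, assuming the standard vllm.LLM entrypoint (model name and values are illustrative, not from this diff):

    from vllm import LLM

    # max_seq_len_to_capture replaces the old max_context_len_to_capture;
    # sequences longer than this fall back to eager-mode execution.
    llm = LLM(
        model="facebook/opt-125m",
        gpu_memory_utilization=0.9,
        swap_space=4,
        max_seq_len_to_capture=8192,
    )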
@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
         list(range(block_size * 2)),
     ]

-    final_seq_lens = [
+    final_prompt_lens = [
         len(prompt + output) + num_steps
         for prompt, output in zip(prompts, prev_output_tokens)
     ]
@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
         prompts,
         num_gpu_blocks,
         block_size,
-        final_seq_lens,
+        final_prompt_lens,
         continuations=prev_output_tokens)

     assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
@@ -103,17 +103,21 @@ def test_same_output_for_single_step():
         [6, 7, 8, 9, 10],
     ]

-    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
+    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

     multi_step_execute_model_data = create_execute_model_data(
         seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts, num_gpu_blocks, block_size,
-            final_seq_lens=final_seq_lens))
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            final_prompt_lens=final_prompt_lens))

     single_step_execute_model_data = create_execute_model_data(
         seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts, num_gpu_blocks, block_size,
-            final_seq_lens=final_seq_lens))
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            final_prompt_lens=final_prompt_lens))

     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
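These tests pre-allocate KV-cache blocks for the longest sequence a run can produce, which is what the renamed list records. A hypothetical worked example of that bookkeeping (token IDs and step count are illustrative):

    # A 5-token prompt decoded for num_steps = 3 steps ends at 8 tokens,
    # so the test reserves blocks for a final length of 8.
    prompt = [11, 12, 13, 14, 15]
    num_steps = 3
    final_prompt_len = len(prompt) + num_steps
    assert final_prompt_len == 8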
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
         random.randint(0, 1000) for _ in range(random.randint(10, 20))
     ] for _ in range(10)]

-    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
+    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

     rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
     multi_step_worker.execute_model = patch_execute_model_with_seeds(
@@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
             num_gpu_blocks,
             block_size,
             continuations=continuations,
-            final_seq_lens=final_seq_lens), )
+            final_prompt_lens=final_prompt_lens), )

     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
@@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
             num_gpu_blocks,
             block_size,
             continuations=continuations,
-            final_seq_lens=final_seq_lens))
+            final_prompt_lens=final_prompt_lens))

         single_step_output.extend(
             worker.execute_model(**execute_model_data.to_dict(), ))
@@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match():
     ]

     proposal_len = 5
-    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
+    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
     ngram_sampler_output_data = create_execute_model_data(
         seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts, num_gpu_blocks, block_size,
-            final_seq_lens=final_seq_lens))
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            final_prompt_lens=final_prompt_lens))

     proposals = proposer.get_proposals(
         **ngram_sampler_output_data.to_dict(),
@@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     ]

     proposal_len = 5
-    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
+    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
     ngram_sampler_output_data = create_execute_model_data(
         seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts, num_gpu_blocks, block_size,
-            final_seq_lens=final_seq_lens))
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            final_prompt_lens=final_prompt_lens))

     proposals = proposer.get_proposals(
         **ngram_sampler_output_data.to_dict(),
@@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all():
     ]

     proposal_len = 5
-    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
+    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
     ngram_sampler_output_data = create_execute_model_data(
         seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts, num_gpu_blocks, block_size,
-            final_seq_lens=final_seq_lens))
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            final_prompt_lens=final_prompt_lens))

     proposals = proposer.get_proposals(
         **ngram_sampler_output_data.to_dict(),
@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
         prompts: List[List[int]],
         num_gpu_blocks: int,
         block_size: int,
-        final_seq_lens: List[int],
+        final_prompt_lens: List[int],
         continuations: Optional[List[List[int]]] = None,
         seq_ids: Optional[List[int]] = None,
 ) -> List[SequenceGroupMetadata]:
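A minimal sketch of calling the helper after this signature change, using only the parameters shown above (prompt token IDs, block counts, and the +5 length budget are illustrative):

    prompts = [[1, 2, 3], [4, 5, 6, 7]]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks=128,
        block_size=16,
        # Budget for 5 generated tokens per sequence (illustrative).
        final_prompt_lens=[len(p) + 5 for p in prompts])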
@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
             free_gpu_blocks.pop()
             for _ in range(round_up_to_next_block(final_len, block_size))
         ]
-        for i, final_len in enumerate(final_seq_lens)
+        for i, final_len in enumerate(final_prompt_lens)
     }

     return [
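The hunk above pops one free GPU block per block needed for the final length. A sketch of the block math this implies; round_up_to_next_block is assumed to be a ceiling division returning a block count, which matches its use in range() above:

    def round_up_to_next_block(seq_len: int, block_size: int) -> int:
        # Assumed behavior: number of blocks needed to hold seq_len tokens.
        return (seq_len + block_size - 1) // block_size

    assert round_up_to_next_block(13, 16) == 1
    assert round_up_to_next_block(17, 16) == 2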
@@ -251,13 +251,13 @@ def create_batch(batch_size,
     prev_output_tokens = [[
         next(iterator) for _ in range(prev_output_token_len)
     ] for _ in range(batch_size)]
-    final_seq_lens = [
+    final_prompt_lens = [
         len(prompt) + len(prev_output_token) + k + 1
         for prompt, prev_output_token in zip(prompts, prev_output_tokens)
     ]

     execute_model_data = create_execute_model_data(
         create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
-                                               block_size, final_seq_lens,
+                                               block_size, final_prompt_lens,
                                                prev_output_tokens, seq_ids), )
     return execute_model_data, prompts, prev_output_tokens
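In create_batch, the length budget is len(prompt) + len(prev_output_token) + k + 1: the tokens already present, k speculative tokens, plus one more, presumably the bonus token the target model can accept beyond the k proposals. A hypothetical arithmetic check (values illustrative):

    # 5 prompt tokens + 3 prior output tokens + 4 speculative + 1 bonus.
    prompt_len, prev_output_len, k = 5, 3, 4
    final_prompt_len = prompt_len + prev_output_len + k + 1
    assert final_prompt_len == 13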