[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
final_seq_lens = [
|
||||
final_prompt_lens = [
|
||||
len(prompt + output) + num_steps
|
||||
for prompt, output in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_seq_lens,
|
||||
final_prompt_lens,
|
||||
continuations=prev_output_tokens)
|
||||
|
||||
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
|
||||
@@ -103,17 +103,21 @@ def test_same_output_for_single_step():
|
||||
[6, 7, 8, 9, 10],
|
||||
]
|
||||
|
||||
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
multi_step_execute_model_data = create_execute_model_data(
|
||||
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size,
|
||||
final_seq_lens=final_seq_lens))
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens))
|
||||
|
||||
single_step_execute_model_data = create_execute_model_data(
|
||||
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size,
|
||||
final_seq_lens=final_seq_lens))
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens))
|
||||
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
|
||||
random.randint(0, 1000) for _ in range(random.randint(10, 20))
|
||||
] for _ in range(10)]
|
||||
|
||||
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
@@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_seq_lens=final_seq_lens), )
|
||||
final_prompt_lens=final_prompt_lens), )
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
@@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_seq_lens=final_seq_lens))
|
||||
final_prompt_lens=final_prompt_lens))
|
||||
|
||||
single_step_output.extend(
|
||||
worker.execute_model(**execute_model_data.to_dict(), ))
|
||||
|
||||
Reference in New Issue
Block a user