[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

This commit is contained in:
SangBin Cho
2024-05-04 02:20:12 +09:00
committed by GitHub
parent 2d7bce9cd5
commit 3521ba4f25
27 changed files with 554 additions and 525 deletions

View File

@@ -45,7 +45,7 @@ class AsyncLLM:
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
@@ -66,7 +66,7 @@ class AsyncLLM:
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
engine_use_ray=True,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,

View File

@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
list(range(block_size * 2)),
]
final_seq_lens = [
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts,
num_gpu_blocks,
block_size,
final_seq_lens,
final_prompt_lens,
continuations=prev_output_tokens)
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
@@ -103,17 +103,21 @@ def test_same_output_for_single_step():
[6, 7, 8, 9, 10],
]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
@@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks,
block_size,
continuations=continuations,
final_seq_lens=final_seq_lens), )
final_prompt_lens=final_prompt_lens), )
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
@@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks,
block_size,
continuations=continuations,
final_seq_lens=final_seq_lens))
final_prompt_lens=final_prompt_lens))
single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), ))

View File

@@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match():
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all():
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),

View File

@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
prompts: List[List[int]],
num_gpu_blocks: int,
block_size: int,
final_seq_lens: List[int],
final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_seq_lens)
for i, final_len in enumerate(final_prompt_lens)
}
return [
@@ -251,13 +251,13 @@ def create_batch(batch_size,
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_seq_lens = [
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_seq_lens,
block_size, final_prompt_lens,
prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens