[MRV2] Consider spec decoding in warmup (#37812)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -35,7 +35,9 @@ def warmup_kernels(
     """
     prompt_token_ids = [0, 1]
     prompt_len = len(prompt_token_ids)
-    decode_len = prompt_len + 1  # After prefill, one decode token is added.
+    num_spec_steps = model_runner.num_speculative_steps
+    # After prefill, decode generates 1 verified + num_spec_steps draft tokens.
+    decode_len = prompt_len + 1 + num_spec_steps
 
     kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups
     num_kv_cache_groups = len(kv_cache_groups)
@@ -51,7 +53,8 @@ def warmup_kernels(
 
     num_reqs = min(
         model_runner.scheduler_config.max_num_seqs,
-        model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
+        model_runner.scheduler_config.max_num_batched_tokens
+        // max(prompt_len, 1 + num_spec_steps),
         # Reserve block 0 (null block) and ensure we have enough blocks.
         max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req),
     )
@@ -111,7 +114,7 @@ def warmup_kernels(
 
     worker_sample_tokens(grammar_output)
 
-    # Step 2: Decode all requests with 1 token each.
+    # Step 2: Decode all requests with 1 + num_spec_steps tokens each.
     cached_req_data = CachedRequestData.make_empty()
     cached_req_data.req_ids = list(req_ids)
     cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
@@ -124,8 +127,16 @@ def warmup_kernels(
 
     decode_output = SchedulerOutput.make_empty()
     decode_output.scheduled_cached_reqs = cached_req_data
-    decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
-    decode_output.total_num_scheduled_tokens = num_reqs
+    decode_output.num_scheduled_tokens = {
+        req_id: 1 + num_spec_steps for req_id in req_ids
+    }
+    if num_spec_steps > 0:
+        decode_output.scheduled_spec_decode_tokens = {
+            req_id: [0] * num_spec_steps for req_id in req_ids
+        }
+    decode_output.total_num_scheduled_tokens = sum(
+        decode_output.num_scheduled_tokens.values()
+    )
     decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
 
     worker_execute_model(decode_output)
Reference in New Issue
Block a user