[MRV2] Consider spec decoding in warmup (#37812)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Woosuk Kwon
2026-03-23 10:45:43 -07:00
committed by GitHub
parent 91fd695b75
commit ffb5b32b5f

View File

@@ -35,7 +35,9 @@ def warmup_kernels(
"""
prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids)
decode_len = prompt_len + 1 # After prefill, one decode token is added.
num_spec_steps = model_runner.num_speculative_steps
# After prefill, decode generates 1 verified + num_spec_steps draft tokens.
decode_len = prompt_len + 1 + num_spec_steps
kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups
num_kv_cache_groups = len(kv_cache_groups)
@@ -51,7 +53,8 @@ def warmup_kernels(
num_reqs = min(
model_runner.scheduler_config.max_num_seqs,
model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
model_runner.scheduler_config.max_num_batched_tokens
// max(prompt_len, 1 + num_spec_steps),
# Reserve block 0 (null block) and ensure we have enough blocks.
max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req),
)
@@ -111,7 +114,7 @@ def warmup_kernels(
worker_sample_tokens(grammar_output)
# Step 2: Decode all requests with 1 token each.
# Step 2: Decode all requests with 1 + num_spec_steps tokens each.
cached_req_data = CachedRequestData.make_empty()
cached_req_data.req_ids = list(req_ids)
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
@@ -124,8 +127,16 @@ def warmup_kernels(
decode_output = SchedulerOutput.make_empty()
decode_output.scheduled_cached_reqs = cached_req_data
decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
decode_output.total_num_scheduled_tokens = num_reqs
decode_output.num_scheduled_tokens = {
req_id: 1 + num_spec_steps for req_id in req_ids
}
if num_spec_steps > 0:
decode_output.scheduled_spec_decode_tokens = {
req_id: [0] * num_spec_steps for req_id in req_ids
}
decode_output.total_num_scheduled_tokens = sum(
decode_output.num_scheduled_tokens.values()
)
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
worker_execute_model(decode_output)