From ffb5b32b5f2a37ea58261747aef9e5f3907b9941 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 23 Mar 2026 10:45:43 -0700 Subject: [PATCH] [MRV2] Consider spec decoding in warmup (#37812) Signed-off-by: Woosuk Kwon Signed-off-by: Nick Hill Co-authored-by: Nick Hill --- vllm/v1/worker/gpu/warmup.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu/warmup.py b/vllm/v1/worker/gpu/warmup.py index 28e480134..026b6a7d7 100644 --- a/vllm/v1/worker/gpu/warmup.py +++ b/vllm/v1/worker/gpu/warmup.py @@ -35,7 +35,9 @@ def warmup_kernels( """ prompt_token_ids = [0, 1] prompt_len = len(prompt_token_ids) - decode_len = prompt_len + 1 # After prefill, one decode token is added. + num_spec_steps = model_runner.num_speculative_steps + # After prefill, decode generates 1 verified + num_spec_steps draft tokens. + decode_len = prompt_len + 1 + num_spec_steps kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups num_kv_cache_groups = len(kv_cache_groups) @@ -51,7 +53,8 @@ def warmup_kernels( num_reqs = min( model_runner.scheduler_config.max_num_seqs, - model_runner.scheduler_config.max_num_batched_tokens // prompt_len, + model_runner.scheduler_config.max_num_batched_tokens + // max(prompt_len, 1 + num_spec_steps), # Reserve block 0 (null block) and ensure we have enough blocks. max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req), ) @@ -111,7 +114,7 @@ def warmup_kernels( worker_sample_tokens(grammar_output) - # Step 2: Decode all requests with 1 token each. + # Step 2: Decode all requests with 1 + num_spec_steps tokens each. cached_req_data = CachedRequestData.make_empty() cached_req_data.req_ids = list(req_ids) cached_req_data.num_computed_tokens = [prompt_len] * num_reqs @@ -124,8 +127,16 @@ def warmup_kernels( decode_output = SchedulerOutput.make_empty() decode_output.scheduled_cached_reqs = cached_req_data - decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids} - decode_output.total_num_scheduled_tokens = num_reqs + decode_output.num_scheduled_tokens = { + req_id: 1 + num_spec_steps for req_id in req_ids + } + if num_spec_steps > 0: + decode_output.scheduled_spec_decode_tokens = { + req_id: [0] * num_spec_steps for req_id in req_ids + } + decode_output.total_num_scheduled_tokens = sum( + decode_output.num_scheduled_tokens.values() + ) decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups worker_execute_model(decode_output)