From a73af584fe6d4c1c2781d537c35e3cc85f58480b Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 5 Mar 2026 14:48:10 -0800
Subject: [PATCH] [Model Runner V2] Fix warmup for very small kvcache and/or
 blocksizes (#36176)

Signed-off-by: Nick Hill
---
 vllm/v1/worker/gpu/warmup.py | 34 ++++++++++++++++++++++++++++++----
 vllm/v1/worker/gpu_worker.py |  4 ++++
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/worker/gpu/warmup.py b/vllm/v1/worker/gpu/warmup.py
index 9d70a56f5..082b4e642 100644
--- a/vllm/v1/worker/gpu/warmup.py
+++ b/vllm/v1/worker/gpu/warmup.py
@@ -5,6 +5,7 @@ import numpy as np
 import torch
 
 from vllm import PoolingParams, SamplingParams
+from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import (
     CachedRequestData,
     GrammarOutput,
@@ -26,12 +27,27 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
     """
     prompt_token_ids = [0, 1]
     prompt_len = len(prompt_token_ids)
+    decode_len = prompt_len + 1  # After prefill, one decode token is added.
+
+    kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups
+    num_kv_cache_groups = len(kv_cache_groups)
+
+    # Compute per-request block counts for each KV cache group.
+    group_block_sizes = [g.kv_cache_spec.block_size for g in kv_cache_groups]
+    prefill_block_counts = [cdiv(prompt_len, bs) for bs in group_block_sizes]
+    decode_block_counts = [cdiv(decode_len, bs) for bs in group_block_sizes]
+    decode_block_deltas = [
+        d - p for d, p in zip(decode_block_counts, prefill_block_counts)
+    ]
+    max_blocks_per_req = sum(decode_block_counts)
+
     num_reqs = min(
         model_runner.scheduler_config.max_num_seqs,
         model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
+        # Reserve block 0 (null block) and ensure we have enough blocks.
+        max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req),
     )
 
-    num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
     req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]
 
     # SamplingParams exercising all sampling features.
@@ -42,12 +58,18 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
     sampling_params = SamplingParams.for_sampler_warmup()
     pooling_params = None
 
+    # Assign distinct block IDs per request per group; block 0 is the null block, so start from 1.
+    next_block_id = 1
+
+    def _alloc_blocks(num_blocks: int) -> list[int]:
+        nonlocal next_block_id
+        return list(range(next_block_id, next_block_id := next_block_id + num_blocks))
+
     # Step 1: Prefill all requests with 2 prompt tokens each.
     new_reqs = [
         NewRequestData.from_request(
             Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
-            # Each request uses a distinct block per KV cache group.
-            block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
+            block_ids=tuple(_alloc_blocks(n) for n in prefill_block_counts),
             prefill_token_ids=prompt_token_ids,
         )
         for i in range(num_reqs)
@@ -84,9 +106,13 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
     # Step 2: Decode all requests with 1 token each.
     cached_req_data = CachedRequestData.make_empty()
     cached_req_data.req_ids = list(req_ids)
-    cached_req_data.new_block_ids = [None] * num_reqs
     cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
     cached_req_data.num_output_tokens = [1] * num_reqs
+    new_block = any(decode_block_deltas)
+    cached_req_data.new_block_ids = [
+        tuple(_alloc_blocks(n) for n in decode_block_deltas) if new_block else None
+        for _ in range(num_reqs)
+    ]
 
     decode_output = SchedulerOutput.make_empty()
     decode_output.scheduled_cached_reqs = cached_req_data
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 4c11aede5..10e9f2f49 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -464,6 +464,10 @@ class Worker(WorkerBase):
 
     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        # Update local config with adjusted num blocks after profiling,
+        # so that it's available to the warmup stage.
+        self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+
         # Init kv cache connector here, because it requires
         # `kv_cache_config`.
         # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
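
The patch caps the number of warmup requests by how many KV cache blocks actually exist. Below is a standalone sketch of that arithmetic, not part of the patch itself: cdiv is assumed to be ceiling division (matching vllm.utils.math_utils.cdiv), and the block size, block count, and single-group layout are made-up illustrative values, not taken from any real config.

    def cdiv(a: int, b: int) -> int:
        # Ceiling division; assumed behavior of vllm.utils.math_utils.cdiv.
        return -(a // -b)

    # Illustrative worst case this patch targets: one KV cache group with a
    # tiny block size and a tiny cache.
    block_size = 1               # tokens per KV cache block (assumed)
    num_blocks = 4               # total blocks, including null block 0 (assumed)
    prompt_len = 2               # warmup prompt [0, 1]
    decode_len = prompt_len + 1  # one decode token added after prefill

    prefill_blocks = cdiv(prompt_len, block_size)  # 2 blocks to hold the prompt
    decode_blocks = cdiv(decode_len, block_size)   # 3 blocks after one decode step
    delta = decode_blocks - prefill_blocks         # 1 extra block needed for decode

    # Block 0 is reserved as the null block, so only num_blocks - 1 blocks are
    # allocatable; cap the request count so every request fits through decode.
    num_reqs = max(1, (num_blocks - 1) // decode_blocks)  # (4 - 1) // 3 == 1

Before the change, warmup assumed one block per request per group (block_ids of [i], which also handed request 0 the reserved null block 0) and never allocated the extra block a decode step can require, so very small caches or block sizes could run past the available blocks.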