[Model Runner V2] Fix warmup for very small kvcache and/or blocksizes (#36176)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-05 14:48:10 -08:00
committed by GitHub
parent a97954b6a8
commit a73af584fe
2 changed files with 34 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ import numpy as np
import torch import torch
from vllm import PoolingParams, SamplingParams from vllm import PoolingParams, SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import ( from vllm.v1.core.sched.output import (
CachedRequestData, CachedRequestData,
GrammarOutput, GrammarOutput,
@@ -26,12 +27,27 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
""" """
prompt_token_ids = [0, 1] prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
decode_len = prompt_len + 1 # After prefill, one decode token is added.
kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups
num_kv_cache_groups = len(kv_cache_groups)
# Compute per-request block counts for each KV cache group.
group_block_sizes = [g.kv_cache_spec.block_size for g in kv_cache_groups]
prefill_block_counts = [cdiv(prompt_len, bs) for bs in group_block_sizes]
decode_block_counts = [cdiv(decode_len, bs) for bs in group_block_sizes]
decode_block_deltas = [
d - p for d, p in zip(decode_block_counts, prefill_block_counts)
]
max_blocks_per_req = sum(decode_block_counts)
num_reqs = min( num_reqs = min(
model_runner.scheduler_config.max_num_seqs, model_runner.scheduler_config.max_num_seqs,
model_runner.scheduler_config.max_num_batched_tokens // prompt_len, model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
# Reserve block 0 (null block) and ensure we have enough blocks.
max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req),
) )
num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
req_ids = [f"_warmup_{i}_" for i in range(num_reqs)] req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]
# SamplingParams exercising all sampling features. # SamplingParams exercising all sampling features.
@@ -42,12 +58,18 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
sampling_params = SamplingParams.for_sampler_warmup() sampling_params = SamplingParams.for_sampler_warmup()
pooling_params = None pooling_params = None
# Assign distinct block IDs per request per group. 0 null block, start from 1.
next_block_id = 1
def _alloc_blocks(num_blocks: int) -> list[int]:
nonlocal next_block_id
return list(range(next_block_id, next_block_id := next_block_id + num_blocks))
# Step 1: Prefill all requests with 2 prompt tokens each. # Step 1: Prefill all requests with 2 prompt tokens each.
new_reqs = [ new_reqs = [
NewRequestData.from_request( NewRequestData.from_request(
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params), Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
# Each request uses a distinct block per KV cache group. block_ids=tuple(_alloc_blocks(n) for n in prefill_block_counts),
block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
prefill_token_ids=prompt_token_ids, prefill_token_ids=prompt_token_ids,
) )
for i in range(num_reqs) for i in range(num_reqs)
@@ -84,9 +106,13 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
# Step 2: Decode all requests with 1 token each. # Step 2: Decode all requests with 1 token each.
cached_req_data = CachedRequestData.make_empty() cached_req_data = CachedRequestData.make_empty()
cached_req_data.req_ids = list(req_ids) cached_req_data.req_ids = list(req_ids)
cached_req_data.new_block_ids = [None] * num_reqs
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
cached_req_data.num_output_tokens = [1] * num_reqs cached_req_data.num_output_tokens = [1] * num_reqs
new_block = any(decode_block_deltas)
cached_req_data.new_block_ids = [
tuple(_alloc_blocks(n) for n in decode_block_deltas) if new_block else None
for _ in range(num_reqs)
]
decode_output = SchedulerOutput.make_empty() decode_output = SchedulerOutput.make_empty()
decode_output.scheduled_cached_reqs = cached_req_data decode_output.scheduled_cached_reqs = cached_req_data

View File

@@ -464,6 +464,10 @@ class Worker(WorkerBase):
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config.""" """Allocate GPU KV cache with the specified kv_cache_config."""
# Update local config with adjusted num blocks after profiling,
# so that it's available to the warmup stage.
self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
# Init kv cache connector here, because it requires # Init kv cache connector here, because it requires
# `kv_cache_config`. # `kv_cache_config`.
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`, # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,