[Model Runner V2] Fix warmup for very small kvcache and/or blocksizes (#36176)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import PoolingParams, SamplingParams
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import (
|
||||
CachedRequestData,
|
||||
GrammarOutput,
|
||||
@@ -26,12 +27,27 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
||||
"""
|
||||
prompt_token_ids = [0, 1]
|
||||
prompt_len = len(prompt_token_ids)
|
||||
decode_len = prompt_len + 1 # After prefill, one decode token is added.
|
||||
|
||||
kv_cache_groups = model_runner.kv_cache_config.kv_cache_groups
|
||||
num_kv_cache_groups = len(kv_cache_groups)
|
||||
|
||||
# Compute per-request block counts for each KV cache group.
|
||||
group_block_sizes = [g.kv_cache_spec.block_size for g in kv_cache_groups]
|
||||
prefill_block_counts = [cdiv(prompt_len, bs) for bs in group_block_sizes]
|
||||
decode_block_counts = [cdiv(decode_len, bs) for bs in group_block_sizes]
|
||||
decode_block_deltas = [
|
||||
d - p for d, p in zip(decode_block_counts, prefill_block_counts)
|
||||
]
|
||||
max_blocks_per_req = sum(decode_block_counts)
|
||||
|
||||
num_reqs = min(
|
||||
model_runner.scheduler_config.max_num_seqs,
|
||||
model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
|
||||
# Reserve block 0 (null block) and ensure we have enough blocks.
|
||||
max(1, (model_runner.kv_cache_config.num_blocks - 1) // max_blocks_per_req),
|
||||
)
|
||||
|
||||
num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
|
||||
req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]
|
||||
|
||||
# SamplingParams exercising all sampling features.
|
||||
@@ -42,12 +58,18 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
||||
sampling_params = SamplingParams.for_sampler_warmup()
|
||||
pooling_params = None
|
||||
|
||||
# Assign distinct block IDs per request per group. 0 null block, start from 1.
|
||||
next_block_id = 1
|
||||
|
||||
def _alloc_blocks(num_blocks: int) -> list[int]:
|
||||
nonlocal next_block_id
|
||||
return list(range(next_block_id, next_block_id := next_block_id + num_blocks))
|
||||
|
||||
# Step 1: Prefill all requests with 2 prompt tokens each.
|
||||
new_reqs = [
|
||||
NewRequestData.from_request(
|
||||
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
|
||||
# Each request uses a distinct block per KV cache group.
|
||||
block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
|
||||
block_ids=tuple(_alloc_blocks(n) for n in prefill_block_counts),
|
||||
prefill_token_ids=prompt_token_ids,
|
||||
)
|
||||
for i in range(num_reqs)
|
||||
@@ -84,9 +106,13 @@ def warmup_kernels(model_runner: GPUModelRunner) -> None:
|
||||
# Step 2: Decode all requests with 1 token each.
|
||||
cached_req_data = CachedRequestData.make_empty()
|
||||
cached_req_data.req_ids = list(req_ids)
|
||||
cached_req_data.new_block_ids = [None] * num_reqs
|
||||
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
|
||||
cached_req_data.num_output_tokens = [1] * num_reqs
|
||||
new_block = any(decode_block_deltas)
|
||||
cached_req_data.new_block_ids = [
|
||||
tuple(_alloc_blocks(n) for n in decode_block_deltas) if new_block else None
|
||||
for _ in range(num_reqs)
|
||||
]
|
||||
|
||||
decode_output = SchedulerOutput.make_empty()
|
||||
decode_output.scheduled_cached_reqs = cached_req_data
|
||||
|
||||
@@ -464,6 +464,10 @@ class Worker(WorkerBase):
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
|
||||
# Update local config with adjusted num blocks after profiling,
|
||||
# so that it's available to the warmup stage.
|
||||
self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
|
||||
|
||||
# Init kv cache connector here, because it requires
|
||||
# `kv_cache_config`.
|
||||
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
|
||||
|
||||
Reference in New Issue
Block a user