[BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU (#35219)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
|
||||
) -> tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
def get_kv_cache_block_dim(
    cls,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> int:
    """Return the index of the block dimension in the KV cache shape.

    Backends lay out the cache dimensions differently, so the position is
    discovered rather than assumed: ``get_kv_cache_shape`` is called with a
    sentinel block count and the sentinel is then located in the result.

    Args:
        block_size: Forwarded to ``get_kv_cache_shape``.
        num_kv_heads: Forwarded to ``get_kv_cache_shape``.
        head_size: Forwarded to ``get_kv_cache_shape``.
        cache_dtype_str: Forwarded to ``get_kv_cache_shape`` by keyword.

    Returns:
        The position of the first dimension whose size equals the sentinel
        block count.

    Raises:
        ValueError: If no dimension of the returned shape equals the
            sentinel (for example, a backend that folds the block count
            into another dimension).
    """
    # Presumably chosen so that it cannot collide with a real block size,
    # head count or head size; the lookup below relies on that.
    sentinel_num_blocks = 1234567
    shape = cls.get_kv_cache_shape(
        sentinel_num_blocks,
        block_size,
        num_kv_heads,
        head_size,
        cache_dtype_str=cache_dtype_str,
    )
    try:
        return shape.index(sentinel_num_blocks)
    except ValueError as err:
        # Same exception type as the bare lookup, but with enough context
        # to identify the offending backend and layout.
        raise ValueError(
            f"{cls.__name__}.get_kv_cache_shape returned {tuple(shape)}, "
            "which has no dimension equal to the number of blocks; cannot "
            "determine the block dimension."
        ) from err
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
|
||||
@@ -501,6 +501,13 @@ class KVCacheManager:
|
||||
# Only create new KVCacheBlocks for non-empty blocks
|
||||
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
|
||||
|
||||
def take_new_block_ids(self) -> list[int]:
    """Gather the pending new block IDs from every single-type manager.

    Each manager in ``self.coordinator.single_type_managers`` is drained via
    its own ``take_new_block_ids``; the results are concatenated in manager
    order. The returned IDs are the ones intended for zeroing.
    """
    return [
        block_id
        for manager in self.coordinator.single_type_managers
        for block_id in manager.take_new_block_ids()
    ]
|
||||
|
||||
def new_step_starts(self) -> None:
    """Tell the coordinator that a new step is starting."""
    coordinator = self.coordinator
    coordinator.new_step_starts()
|
||||
|
||||
@@ -233,6 +233,11 @@ class SchedulerOutput:
|
||||
# EC Cache Connector metadata
|
||||
ec_connector_metadata: ECConnectorMetadata | None = None
|
||||
|
||||
# Block IDs freshly allocated from the pool during this scheduling step.
|
||||
# The worker zeros the corresponding GPU memory before the blocks are used,
|
||||
# preventing stale NaN/data from corrupting attention or SSM computation.
|
||||
new_block_ids_to_zero: list[int] | None = None
|
||||
|
||||
@classmethod
|
||||
def make_empty(cls) -> "SchedulerOutput":
|
||||
return cls(
|
||||
|
||||
@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
|
||||
return any(
|
||||
isinstance(group_spec.kv_cache_spec, MambaSpec)
|
||||
for group_spec in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
|
||||
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
|
||||
self.has_mamba_layers = kv_cache_config.has_mamba_layers
|
||||
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
|
||||
self.need_mamba_block_aligned_split = (
|
||||
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
|
||||
)
|
||||
@@ -890,6 +885,12 @@ class Scheduler(SchedulerInterface):
|
||||
self.prev_step_scheduled_req_ids.clear()
|
||||
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
|
||||
|
||||
new_block_ids_to_zero = (
|
||||
(self.kv_cache_manager.take_new_block_ids() or None)
|
||||
if self.needs_kv_cache_zeroing
|
||||
else None
|
||||
)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
scheduled_cached_reqs=cached_reqs_data,
|
||||
@@ -905,6 +906,7 @@ class Scheduler(SchedulerInterface):
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||
new_block_ids_to_zero=new_block_ids_to_zero,
|
||||
)
|
||||
|
||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||
|
||||
@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
self.enable_caching = enable_caching
|
||||
self.new_block_ids: list[int] = []
|
||||
|
||||
# Mapping from request ID to blocks to track the blocks allocated
|
||||
# for each request, so that we can free the blocks when the request
|
||||
@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
|
||||
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
|
||||
)
|
||||
req_blocks.extend(allocated_blocks)
|
||||
if type(self.kv_cache_spec) is FullAttentionSpec:
|
||||
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
|
||||
|
||||
def allocate_new_blocks(
|
||||
self, request_id: str, num_tokens: int, num_tokens_main_model: int
|
||||
@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
|
||||
else:
|
||||
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||
req_blocks.extend(new_blocks)
|
||||
if type(self.kv_cache_spec) is FullAttentionSpec:
|
||||
self.new_block_ids.extend(b.block_id for b in new_blocks)
|
||||
return new_blocks
|
||||
|
||||
def take_new_block_ids(self) -> list[int]:
    """Hand over the accumulated new block IDs and reset the accumulator.

    Returns the list object currently stored in ``self.new_block_ids`` and
    rebinds the attribute to a fresh empty list, so each ID is reported by
    exactly one call.
    """
    drained, self.new_block_ids = self.new_block_ids, []
    return drained
|
||||
|
||||
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||
"""
|
||||
Cache the blocks for the request.
|
||||
|
||||
@@ -489,3 +489,11 @@ class KVCacheConfig:
|
||||
For models with multiple types of attention, there will be multiple groups,
|
||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||
"""
|
||||
|
||||
@property
def has_mamba_layers(self) -> bool:
    """Whether at least one KV cache group uses a ``MambaSpec``."""
    for group in self.kv_cache_groups:
        if isinstance(group.kv_cache_spec, MambaSpec):
            return True
    return False
|
||||
|
||||
@property
def needs_kv_cache_zeroing(self) -> bool:
    """Whether KV cache zeroing is required for this configuration.

    This currently mirrors ``has_mamba_layers`` exactly.
    """
    zeroing_required = self.has_mamba_layers
    return zeroing_required
|
||||
|
||||
@@ -197,6 +197,7 @@ from vllm.v1.worker.workspace import lock_workspace
|
||||
|
||||
from .utils import (
|
||||
AttentionGroup,
|
||||
KVBlockZeroer,
|
||||
add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache,
|
||||
prepare_kernel_block_sizes,
|
||||
@@ -982,6 +983,26 @@ class GPUModelRunner(
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
)
|
||||
|
||||
def _init_kv_zero_meta(self) -> None:
    """Create the runner's ``KVBlockZeroer`` and precompute its metadata.

    A new ``KVBlockZeroer`` is stored on ``self._kv_block_zeroer`` and its
    ``init_meta`` is invoked with the runner's attention groups, kernel
    block sizes, cache dtype, runner-only attention layers and static
    forward context. This is the one-time setup for ``_zero_block_ids``.

    Called from gpu_worker.py outside the CuMem pool context.
    """
    zeroer = KVBlockZeroer(self.device, self.pin_memory)
    # Publish the instance before initialising it, so the attribute exists
    # even if init_meta raises.
    self._kv_block_zeroer = zeroer
    zeroer.init_meta(
        attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
        kernel_block_sizes=self._kernel_block_sizes,
        cache_dtype=self.cache_config.cache_dtype,
        runner_only_attn_layers=self.runner_only_attn_layers,
        static_forward_context=self.compilation_config.static_forward_context,
    )
|
||||
|
||||
def _zero_block_ids(self, block_ids: list[int]) -> None:
    """Zero the KV cache memory backing ``block_ids``.

    Does nothing when ``self._kv_block_zeroer`` has not been created;
    otherwise the work is delegated to its ``zero_block_ids``.
    """
    if not hasattr(self, "_kv_block_zeroer"):
        return
    zeroer = self._kv_block_zeroer
    zeroer.zero_block_ids(block_ids)
|
||||
|
||||
# Note: used for model runner override.
|
||||
def _init_device_properties(self) -> None:
|
||||
"""Initialize attributes from torch.cuda.get_device_properties"""
|
||||
@@ -1018,6 +1039,11 @@ class GPUModelRunner(
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.input_batch.remove_request(req_id)
|
||||
|
||||
# Zero GPU memory for freshly allocated cache blocks to prevent
|
||||
# stale NaN/data from corrupting attention or SSM computation.
|
||||
if scheduler_output.new_block_ids_to_zero:
|
||||
self._zero_block_ids(scheduler_output.new_block_ids_to_zero)
|
||||
|
||||
# Free the cached encoder outputs.
|
||||
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
@@ -6476,6 +6502,7 @@ class GPUModelRunner(
|
||||
kernel_block_sizes = prepare_kernel_block_sizes(
|
||||
kv_cache_config, self.attn_groups
|
||||
)
|
||||
self._kernel_block_sizes = kernel_block_sizes
|
||||
|
||||
# create metadata builders
|
||||
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
|
||||
|
||||
@@ -556,6 +556,14 @@ class Worker(WorkerBase):
|
||||
else:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
|
||||
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
|
||||
# allocator and are not discarded during sleep/wake cycles.
|
||||
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
|
||||
self.model_runner, "_init_kv_zero_meta"
|
||||
):
|
||||
self.model_runner._init_kv_zero_meta()
|
||||
|
||||
@instrument(span_name="Warmup (GPU)")
|
||||
def compile_or_warm_up_model(self) -> float:
|
||||
warmup_sizes: list[int] = []
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import product as iprod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,6 +15,8 @@ from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import largest_power_of_2_divisor
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
@@ -21,6 +26,7 @@ from vllm.v1.attention.backend import (
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpec,
|
||||
@@ -31,6 +37,186 @@ from vllm.v1.kv_cache_interface import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
def _zero_kv_blocks_kernel(
    seg_addrs_ptr,
    block_ids_ptr,
    n_blocks,
    N_SEGS: tl.constexpr,
    PAGE_SIZE_EL: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write zeros over the given KV cache blocks in every segment.

    A segment is a contiguous region holding one block's data. Backends
    with blocks outermost (block_dim=0) have one segment per buffer;
    backends with K/V outermost (block_dim=1) have two per buffer, one for
    K and one for V.

    seg_addrs_ptr points at int64 absolute byte addresses, one per segment,
    so segments may live in different CUDA allocations. Memory is written
    as 32-bit words: PAGE_SIZE_EL is the per-block size in int32 elements
    and BLOCK_SIZE is the number of int32 elements one program stores.

    The flat program id decomposes as (block_index, seg_index, chunk_index).
    """
    program = tl.program_id(0)
    chunks_per_page = PAGE_SIZE_EL // BLOCK_SIZE
    programs_per_block = N_SEGS * chunks_per_page

    # Slowest-varying component: which entry of block_ids_ptr to process.
    blk = program // programs_per_block
    if blk >= n_blocks:
        return

    # Remaining components: segment, then chunk within the page.
    within_block = program % programs_per_block
    seg = within_block // chunks_per_page
    chunk = within_block % chunks_per_page

    block_id = tl.load(block_ids_ptr + blk)
    seg_addr = tl.load(seg_addrs_ptr + seg)
    # Reinterpret the raw byte address as a pointer to 32-bit words.
    base = tl.cast(seg_addr, tl.pointer_type(tl.int32))

    # 64-bit arithmetic for the element offset into the segment.
    page_start = block_id.to(tl.int64) * PAGE_SIZE_EL
    chunk_start = chunk.to(tl.int64) * BLOCK_SIZE
    offset = page_start + chunk_start
    lanes = tl.arange(0, BLOCK_SIZE).to(tl.int64)
    tl.store(base + offset + lanes, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
|
||||
|
||||
|
||||
class KVBlockZeroer:
    """Manages efficient zeroing of KV cache blocks via a Triton kernel.

    Call :meth:`init_meta` once after KV caches are allocated to precompute
    segment addresses, then call :meth:`zero_block_ids` each step to zero
    newly-allocated blocks.
    """

    def __init__(self, device: torch.device, pin_memory: bool):
        """Remember the target device; no buffers are allocated yet.

        Args:
            device: Device on which the address table and the block-id
                buffer are created.
            pin_memory: Whether the host-side staging buffer is pinned.
        """
        self.device = device
        self.pin_memory = pin_memory
        # (segment address table, page size in int32 elements,
        #  elements stored per program, number of segments); None until
        # init_meta finds at least one segment.
        self._meta: tuple[torch.Tensor, int, int, int] | None = None
        # Capacity, in block IDs, of the two staging buffers below.
        self._id_cap: int = 0
        self._ids_pinned: torch.Tensor | None = None
        self._ids_gpu: torch.Tensor | None = None

    def _alloc_id_buffers(self, capacity: int) -> None:
        """(Re)allocate the host and device block-id buffers for ``capacity`` IDs."""
        self._id_cap = capacity
        self._ids_pinned = torch.empty(
            capacity,
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        self._ids_gpu = torch.empty(capacity, dtype=torch.int64, device=self.device)

    def init_meta(
        self,
        attn_groups_iter: Iterable["AttentionGroup"],
        kernel_block_sizes: list[int],
        cache_dtype: str,
        runner_only_attn_layers: set[str],
        static_forward_context: dict[str, Any],
    ) -> None:
        """One-time precomputation for zero_block_ids.

        Builds the absolute-address table for the Triton zeroing kernel.
        Each entry is the absolute byte address of a segment start, so
        segments in different CUDA allocations work correctly.

        Block IDs from the scheduler reference logical blocks whose size
        may differ from the kernel block size (virtual block splitting).
        PAGE_SIZE_EL accounts for this ratio so that
        ``block_id * PAGE_SIZE_EL`` lands at the correct offset.

        Only groups whose spec is exactly ``FullAttentionSpec`` are
        processed; every other spec type is skipped.

        Args:
            attn_groups_iter: Attention groups to inspect.
            kernel_block_sizes: Kernel block size per KV cache group,
                indexed by ``kv_cache_group_id``; groups with no entry are
                skipped.
            cache_dtype: Cache dtype string forwarded to the backend's
                ``get_kv_cache_block_dim``.
            runner_only_attn_layers: Layer names to leave out.
            static_forward_context: Maps a layer name to an object whose
                ``kv_cache[0]`` is that layer's cache tensor.

        Raises:
            AssertionError: If a block's byte size is not a multiple of 4
                or if layers disagree on the page size.
        """
        seen_ptrs: set[int] = set()
        seg_addrs: list[int] = []
        page_size_el: int | None = None

        for group in attn_groups_iter:
            spec = group.kv_cache_spec
            # Exact type match on purpose: subclasses are not handled here.
            if type(spec) is not FullAttentionSpec:
                continue
            if group.kv_cache_group_id >= len(kernel_block_sizes):
                continue
            kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
            # Number of kernel blocks making up one scheduler block.
            ratio = spec.block_size // kernel_bs
            block_dim = group.backend.get_kv_cache_block_dim(
                kernel_bs,
                spec.num_kv_heads,
                spec.head_size,
                cache_dtype_str=cache_dtype,
            )

            for layer_name in group.layer_names:
                if layer_name in runner_only_attn_layers:
                    continue
                kv = static_forward_context[layer_name].kv_cache[0]
                if isinstance(kv, list):
                    continue
                dp = kv.data_ptr()
                # Tensors starting at the same address are handled once.
                if dp in seen_ptrs:
                    continue
                seen_ptrs.add(dp)

                el = kv.element_size()
                # Bytes between consecutive kernel blocks of this tensor.
                cur_bytes = kv.stride(block_dim) * el
                # The kernel writes 32-bit words, so a block must be a
                # whole number of them.
                assert cur_bytes % 4 == 0, (
                    f"KV block stride of {cur_bytes} bytes is not a multiple of 4"
                )
                kernel_block_el = cur_bytes // 4
                cur_page_el = kernel_block_el * ratio
                if page_size_el is None:
                    page_size_el = cur_page_el
                else:
                    assert page_size_el == cur_page_el, (
                        f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
                    )

                block_stride_bytes = cur_bytes
                # Dimensions ahead of the block dimension whose stride
                # exceeds one block sit outside the blocks in memory;
                # every index combination over them starts a segment.
                outer_dims = [
                    d
                    for d in range(block_dim)
                    if kv.stride(d) * el > block_stride_bytes
                ]
                outer_strides = [kv.stride(d) * el for d in outer_dims]
                for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
                    off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
                    seg_addrs.append(dp + off_bytes)

        if not seg_addrs or page_size_el is None:
            self._meta = None
            return

        # Elements stored per kernel program: a power of two dividing the
        # page size, capped at 1024.
        blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
        self._alloc_id_buffers(8192)
        self._meta = (
            torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
            page_size_el,
            blk_size,
            len(seg_addrs),
        )

    def zero_block_ids(self, block_ids: list[int]) -> None:
        """Zero the KV cache memory for the given block IDs.

        Does nothing when ``block_ids`` is empty or when :meth:`init_meta`
        has not produced metadata.

        Args:
            block_ids: Scheduler block IDs whose memory should be zeroed.
        """
        if not block_ids or self._meta is None:
            return
        seg_addrs, page_size_el, blk_size, n_segs = self._meta
        n_blocks = len(block_ids)
        if n_blocks > self._id_cap:
            # Grow to twice the request so repeated growth is infrequent.
            self._alloc_id_buffers(n_blocks * 2)
        assert self._ids_pinned is not None and self._ids_gpu is not None
        # Stage the IDs in the host buffer, then copy them to the device
        # buffer (requested as a non-blocking copy).
        self._ids_pinned[:n_blocks].numpy()[:] = block_ids
        idx = self._ids_gpu[:n_blocks]
        idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
        # One program per (block, segment, chunk of blk_size elements).
        grid = (n_blocks * n_segs * (page_size_el // blk_size),)
        _zero_kv_blocks_kernel[grid](
            seg_addrs,
            idx,
            n_blocks,
            N_SEGS=n_segs,
            PAGE_SIZE_EL=page_size_el,
            BLOCK_SIZE=blk_size,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
|
||||
Reference in New Issue
Block a user