[BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU (#35219)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-03-10 14:32:20 +04:00
committed by GitHub
parent 507ddbe992
commit 4ff8c3c8f9
10 changed files with 287 additions and 8 deletions

View File

@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
) -> tuple[int, ...]:
raise NotImplementedError
@classmethod
def get_kv_cache_block_dim(
    cls,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> int:
    """Return the position of the block-index dimension in the KV cache shape.

    Backends order the cache dimensions differently, so the position is
    found by probing: ``get_kv_cache_shape`` is asked for a shape using a
    distinctive placeholder as the block count, and the index of the first
    dimension equal to that placeholder is returned.

    Raises:
        ValueError: If no dimension of the probed shape equals the
            placeholder block count.
    """
    # Placeholder block count; chosen so it is recognisable in the shape.
    probe_num_blocks = 1234567
    probed_shape = cls.get_kv_cache_shape(
        probe_num_blocks,
        block_size,
        num_kv_heads,
        head_size,
        cache_dtype_str=cache_dtype_str,
    )
    block_dim = probed_shape.index(probe_num_blocks)
    return block_dim
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,

View File

@@ -501,6 +501,13 @@ class KVCacheManager:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
def take_new_block_ids(self) -> list[int]:
    """Drain every single-type manager and return the collected block IDs.

    Managers are visited in the order of
    ``self.coordinator.single_type_managers``; each manager's
    ``take_new_block_ids()`` result is appended in turn. The returned IDs
    are the new attention block IDs that are to be zeroed.
    """
    return [
        block_id
        for manager in self.coordinator.single_type_managers
        for block_id in manager.take_new_block_ids()
    ]
def new_step_starts(self) -> None:
"""Called when a new step is started."""
self.coordinator.new_step_starts()

View File

@@ -233,6 +233,11 @@ class SchedulerOutput:
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
# Block IDs freshly allocated from the pool during this scheduling step.
# The worker zeros the corresponding GPU memory before the blocks are used,
# preventing stale NaN/data from corrupting attention or SSM computation.
new_block_ids_to_zero: list[int] | None = None
@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(

View File

@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
return any(
isinstance(group_spec.kv_cache_spec, MambaSpec)
for group_spec in kv_cache_config.kv_cache_groups
)
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
self.has_mamba_layers = kv_cache_config.has_mamba_layers
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
)
@@ -890,6 +885,12 @@ class Scheduler(SchedulerInterface):
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None)
if self.needs_kv_cache_zeroing
else None
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -905,6 +906,7 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
)
# NOTE(Kuntai): this function is designed for multiple purposes:

View File

@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_tokens_main_model: int
@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks
def take_new_block_ids(self) -> list[int]:
    """Hand back the block IDs recorded since the previous call and reset.

    The accumulated ``self.new_block_ids`` list itself is returned (not a
    copy) and replaced by a fresh empty list.
    """
    drained, self.new_block_ids = self.new_block_ids, []
    return drained
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.

View File

@@ -489,3 +489,11 @@ class KVCacheConfig:
For models with multiple types of attention, there will be multiple groups,
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
@property
def has_mamba_layers(self) -> bool:
    """True when at least one KV cache group's spec is a ``MambaSpec``."""
    for group in self.kv_cache_groups:
        if isinstance(group.kv_cache_spec, MambaSpec):
            return True
    return False
@property
def needs_kv_cache_zeroing(self) -> bool:
    """Whether this configuration requires KV cache block zeroing.

    Currently this is exactly ``has_mamba_layers``: zeroing is requested
    if and only if some KV cache group uses a ``MambaSpec``.
    """
    return self.has_mamba_layers

View File

@@ -197,6 +197,7 @@ from vllm.v1.worker.workspace import lock_workspace
from .utils import (
AttentionGroup,
KVBlockZeroer,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
prepare_kernel_block_sizes,
@@ -982,6 +983,26 @@ class GPUModelRunner(
decode_threshold=self.reorder_batch_threshold,
)
def _init_kv_zero_meta(self) -> None:
    """Create the block zeroer and precompute its metadata (one-time setup).

    Builds a ``KVBlockZeroer`` for this runner's device, stores it as
    ``self._kv_block_zeroer`` and forwards the runner's state to
    ``KVBlockZeroer.init_meta``. ``_zero_block_ids`` relies on this
    having been run.

    Called from gpu_worker.py outside the CuMem pool context.
    """
    zeroer = KVBlockZeroer(self.device, self.pin_memory)
    self._kv_block_zeroer = zeroer
    meta_inputs = dict(
        attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
        kernel_block_sizes=self._kernel_block_sizes,
        cache_dtype=self.cache_config.cache_dtype,
        runner_only_attn_layers=self.runner_only_attn_layers,
        static_forward_context=self.compilation_config.static_forward_context,
    )
    zeroer.init_meta(**meta_inputs)
def _zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if hasattr(self, "_kv_block_zeroer"):
self._kv_block_zeroer.zero_block_ids(block_ids)
# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties"""
@@ -1018,6 +1039,11 @@ class GPUModelRunner(
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
# Zero GPU memory for freshly allocated cache blocks to prevent
# stale NaN/data from corrupting attention or SSM computation.
if scheduler_output.new_block_ids_to_zero:
self._zero_block_ids(scheduler_output.new_block_ids_to_zero)
# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
@@ -6476,6 +6502,7 @@ class GPUModelRunner(
kernel_block_sizes = prepare_kernel_block_sizes(
kv_cache_config, self.attn_groups
)
self._kernel_block_sizes = kernel_block_sizes
# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)

View File

@@ -556,6 +556,14 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float:
warmup_sizes: list[int] = []

View File

@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any
import torch
@@ -12,6 +15,8 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -21,6 +26,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
@@ -31,6 +37,186 @@ from vllm.v1.kv_cache_interface import (
logger = init_logger(__name__)
@triton.jit
def _zero_kv_blocks_kernel(
    seg_addrs_ptr,
    block_ids_ptr,
    n_blocks,
    N_SEGS: tl.constexpr,
    PAGE_SIZE_EL: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Zero KV cache blocks across all segments in a single launch.

    Each segment is a contiguous region of one block's data. For backends
    where blocks are outermost (block_dim=0) there is one segment per
    buffer. For backends where K/V is outermost (block_dim=1) there are
    two segments per buffer (one for K, one for V).

    seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
    allowing segments to live in different CUDA allocations.

    Programs are mapped as (block_index, seg_index, chunk_index).

    Args:
        seg_addrs_ptr: Table of segment base addresses, indexed by segment.
        block_ids_ptr: Block IDs to zero, indexed by position in the batch.
        n_blocks: Number of valid entries in ``block_ids_ptr``.
        N_SEGS: Number of entries in the segment table.
        PAGE_SIZE_EL: Distance between consecutive blocks inside a segment,
            counted in int32 elements (the segment address is reinterpreted
            as an int32 pointer below).
        BLOCK_SIZE: Number of int32 elements written by one program.
    """
    pid = tl.program_id(0)
    # Each block of each segment is split into `chunks` pieces of BLOCK_SIZE
    # elements; integer division, so any remainder of PAGE_SIZE_EL that is
    # not a multiple of BLOCK_SIZE would not be covered.
    chunks = PAGE_SIZE_EL // BLOCK_SIZE
    work_per_block = N_SEGS * chunks
    block_index = pid // work_per_block
    # Guard for program IDs that fall beyond the last requested block.
    if block_index >= n_blocks:
        return
    # Decompose the flat program ID into (segment, chunk) within the block.
    remainder = pid % work_per_block
    seg_index = remainder // chunks
    chunk_index = remainder % chunks
    block_id = tl.load(block_ids_ptr + block_index)
    seg_addr = tl.load(seg_addrs_ptr + seg_index)
    # Reinterpret the raw address as an int32 pointer so the offsets below
    # are expressed in 4-byte elements.
    ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
    # Offsets are computed in int64 before being added to the pointer.
    offset = (
        block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
    )
    cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
    tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
class KVBlockZeroer:
    """Manages efficient zeroing of KV cache blocks via a Triton kernel.

    Call :meth:`init_meta` once after KV caches are allocated to precompute
    segment addresses, then call :meth:`zero_block_ids` each step to zero
    newly-allocated blocks.
    """

    # Initial capacity (in block IDs) of the host/device staging buffers.
    _INITIAL_ID_CAPACITY = 8192
    # Upper bound on the number of int32 elements one kernel program writes.
    _MAX_ELEMS_PER_PROGRAM = 1024

    def __init__(self, device: torch.device, pin_memory: bool):
        """Store the target device and pinning policy; allocate nothing yet.

        Args:
            device: Device on which the address table and the block-ID
                buffer are created.
            pin_memory: Passed as ``pin_memory`` when allocating the host
                staging buffer.
        """
        self.device = device
        self.pin_memory = pin_memory
        # (segment address table, page size in int32 elements, elements per
        # kernel program, number of segments). Stays None until init_meta
        # finds at least one eligible layer.
        self._meta: tuple[torch.Tensor, int, int, int] | None = None
        self._id_cap: int = 0
        self._ids_pinned: torch.Tensor | None = None
        self._ids_gpu: torch.Tensor | None = None

    def _alloc_id_buffers(self, capacity: int) -> None:
        """(Re)allocate the host and device block-ID staging buffers."""
        ids_pinned = torch.empty(
            capacity,
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        ids_gpu = torch.empty(capacity, dtype=torch.int64, device=self.device)
        # Publish only after both allocations succeeded, so a failed
        # allocation can never leave _id_cap larger than the real buffers.
        self._ids_pinned = ids_pinned
        self._ids_gpu = ids_gpu
        self._id_cap = capacity

    def init_meta(
        self,
        attn_groups_iter: Iterable["AttentionGroup"],
        kernel_block_sizes: list[int],
        cache_dtype: str,
        runner_only_attn_layers: set[str],
        static_forward_context: dict[str, Any],
    ) -> None:
        """One-time precomputation for zero_block_ids.

        Builds absolute-address table for the Triton zeroing kernel.
        Each entry is the absolute byte address of a segment start on the
        GPU, so segments in different CUDA allocations work correctly.

        Block IDs from the scheduler reference logical blocks whose size
        may differ from the kernel block size (virtual block splitting).
        PAGE_SIZE_EL accounts for this ratio so that
        ``block_id * PAGE_SIZE_EL`` lands at the correct offset.

        Only groups whose spec type is exactly ``FullAttentionSpec`` are
        processed; groups with any other spec type (Mamba included) are
        skipped. If no eligible layer is found, the zeroer stays disabled
        and :meth:`zero_block_ids` becomes a no-op.

        Args:
            attn_groups_iter: Attention groups to scan; each must expose
                ``kv_cache_spec``, ``kv_cache_group_id``, ``backend`` and
                ``layer_names``.
            kernel_block_sizes: Kernel block size per KV cache group,
                indexed by ``kv_cache_group_id``. Groups whose ID is out of
                range are skipped.
            cache_dtype: Cache dtype string forwarded to the backend's
                ``get_kv_cache_block_dim``.
            runner_only_attn_layers: Layer names to skip.
            static_forward_context: Maps a layer name to an object whose
                ``kv_cache[0]`` is that layer's cache.

        Raises:
            AssertionError: If a block stride is not a multiple of 4 bytes,
                or if the processed layers do not share one page size.
        """
        seen_ptrs: set[int] = set()
        seg_addrs: list[int] = []
        page_size_el: int | None = None
        for group in attn_groups_iter:
            spec = group.kv_cache_spec
            # Exact type match on purpose: subclasses are not processed.
            if type(spec) is not FullAttentionSpec:
                continue
            if group.kv_cache_group_id >= len(kernel_block_sizes):
                continue
            kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
            # Number of kernel blocks that make up one scheduler block.
            ratio = spec.block_size // kernel_bs
            block_dim = group.backend.get_kv_cache_block_dim(
                kernel_bs,
                spec.num_kv_heads,
                spec.head_size,
                cache_dtype_str=cache_dtype,
            )
            for layer_name in group.layer_names:
                if layer_name in runner_only_attn_layers:
                    continue
                kv = static_forward_context[layer_name].kv_cache[0]
                # A list entry has no single data pointer to address.
                if isinstance(kv, list):
                    continue
                dp = kv.data_ptr()
                # Layers that share the same buffer are zeroed only once.
                if dp in seen_ptrs:
                    continue
                seen_ptrs.add(dp)

                el = kv.element_size()
                block_stride_bytes = kv.stride(block_dim) * el
                # The kernel writes int32 words, so sizes are in 4-byte units.
                assert block_stride_bytes % 4 == 0, (
                    f"Block stride of {block_stride_bytes} bytes is not a "
                    "multiple of 4"
                )
                kernel_block_el = block_stride_bytes // 4
                cur_page_el = kernel_block_el * ratio
                if page_size_el is None:
                    page_size_el = cur_page_el
                else:
                    assert page_size_el == cur_page_el, (
                        f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
                    )

                # Dims laid out outside the block dim (e.g. a leading K/V
                # dim): every index combination starts a separate segment.
                outer_dims = [
                    d
                    for d in range(block_dim)
                    if kv.stride(d) * el > block_stride_bytes
                ]
                outer_strides = [kv.stride(d) * el for d in outer_dims]
                # With no outer dims, the product yields one empty tuple and
                # the single segment starts at the buffer base address.
                seg_addrs.extend(
                    dp + sum(i * s for i, s in zip(outer, outer_strides))
                    for outer in iprod(*(range(kv.shape[d]) for d in outer_dims))
                )

        if not seg_addrs or page_size_el is None:
            self._meta = None
            return

        # A power-of-two divisor of the page size guarantees that the chunks
        # tile each page exactly.
        blk_size = min(
            largest_power_of_2_divisor(page_size_el),
            self._MAX_ELEMS_PER_PROGRAM,
        )
        self._alloc_id_buffers(self._INITIAL_ID_CAPACITY)
        self._meta = (
            torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
            page_size_el,
            blk_size,
            len(seg_addrs),
        )

    def zero_block_ids(self, block_ids: list[int]) -> None:
        """Zero the KV cache memory for the given block IDs.

        No-op when ``block_ids`` is empty or when :meth:`init_meta` has not
        produced any metadata.
        """
        if not block_ids or self._meta is None:
            return
        seg_addrs, page_size_el, blk_size, n_segs = self._meta
        n_blocks = len(block_ids)
        if n_blocks > self._id_cap:
            # Grow to twice the requested size to limit reallocations.
            self._alloc_id_buffers(n_blocks * 2)
        assert self._ids_pinned is not None and self._ids_gpu is not None
        # Stage the IDs in the host buffer, then copy them to the device.
        # NOTE(review): the host buffer is reused across calls while the copy
        # is non_blocking; confirm the caller's stream ordering guarantees
        # the previous copy has completed before the next call overwrites it.
        self._ids_pinned[:n_blocks].numpy()[:] = block_ids
        idx = self._ids_gpu[:n_blocks]
        idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
        # One program per (block, segment, chunk) triple.
        grid = (n_blocks * n_segs * (page_size_el // blk_size),)
        _zero_kv_blocks_kernel[grid](
            seg_addrs,
            idx,
            n_blocks,
            N_SEGS=n_segs,
            PAGE_SIZE_EL=page_size_el,
            BLOCK_SIZE=blk_size,
        )
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]