[V0 Deprecation] Refactor kv cache from list to element (#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2026-03-23 23:10:11 -04:00 (committed by GitHub)
Parent: de99d91ece
Commit: c59a132f96
27 changed files with 70 additions and 85 deletions
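
The refactor is mechanical across every touched layer: the per-layer kv_cache attribute changes from a list of placeholder tensors indexed by pipeline-parallel rank to a single placeholder tensor, and every call site drops one level of [0] indexing. A minimal Python sketch of the pattern (the AttentionLayerSketch class and all shapes are illustrative only; just the kv_cache attribute name and the bind_kv_cache replacement step come from the diff):

    import torch

    class AttentionLayerSketch:
        # Illustrative stand-in for the real Attention / MLAAttention layers.
        def __init__(self) -> None:
            # Before this commit: one placeholder per pipeline-parallel rank, e.g.
            #   self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]
            # After this commit: a single placeholder, later swapped by bind_kv_cache.
            self.kv_cache = torch.tensor([])

    def read_kv_cache(layer: AttentionLayerSketch) -> torch.Tensor:
        # Call sites previously unwrapped the list with layer.kv_cache[0];
        # after the refactor they read the tensor directly.
        return layer.kv_cache

    layer = AttentionLayerSketch()
    # Emulate what bind_kv_cache would do: replace the placeholder with a real
    # allocation (shape chosen arbitrarily for the sketch).
    layer.kv_cache = torch.zeros(2, 16, 8, 64)
    print(read_kv_cache(layer).shape)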

View File

@@ -349,10 +349,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # use a placeholder kv cache tensor during init, which will be replaced
         # by bind_kv_cache
         # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        self.kv_cache = torch.tensor([])
         # Initialize KV cache quantization attributes
         _init_kv_cache_quant(self, quant_config, prefix)
@@ -599,7 +596,7 @@ def get_attention_context(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (
         f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "

View File

@@ -415,12 +415,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        self.kv_cache = torch.tensor([])
         self.use_sparse = use_sparse
@@ -479,7 +474,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         attn_metadata = forward_context.attn_metadata
         if isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata[self.layer_name]
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         slot_mapping = forward_context.slot_mapping
         assert isinstance(slot_mapping, dict), (
@@ -939,7 +934,7 @@ def unified_mla_kv_cache_update(
         return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
     attn_layer = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (

View File

@@ -168,7 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
"sink_key and sink_value have not been prepared"
)
if not self.sink_populated:
self_kv_cache = self.kv_cache[0]
self_kv_cache = self.kv_cache
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
return super().forward(query, key, value, output_shape)

View File

@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         num_actual_tokens = attn_metadata.num_actual_tokens
-        constant_caches = self.kv_cache[0]
+        constant_caches = self.kv_cache
         q_proj_states = q_proj_states[:num_actual_tokens]
         k_proj_states = k_proj_states[:num_actual_tokens]

View File

@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[0][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             clear_linear_attention_cache_for_new_sequences(
                 kv_cache, state_indices_tensor, attn_metadata
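
Note that for the linear-attention layers (here and in BailingMoELinearAttention further down) only the outer pipeline-stage index goes away; the surviving [0] still picks the state tensor out of the layer's bound cache group. A tiny sketch under that assumption (the one-element group and the shape are illustrative, not taken from vLLM):

    import torch

    # Assumption for illustration: the cache bound to a linear-attention layer
    # is a one-element group holding its recurrent state, hence one [0] remains.
    state = torch.zeros(4, 8, 64, 64)   # made-up per-request state shape
    layer_kv_cache = (state,)

    # Before: kv_cache = self.kv_cache[0][0]  (pipeline stage, then group element)
    # After:  kv_cache = self.kv_cache[0]     (group element only)
    kv_cache = layer_kv_cache[0]
    print(kv_cache.shape)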

View File

@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
         query_start_loc_p = attn_metadata.query_start_loc_p
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         has_initial_states_p = attn_metadata.has_initial_states_p
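
The Mamba-style mixers (MambaMixer above and MambaMixer2, ShortConv, the gated-delta nets, and Plamo2 below) show the same one-level drop: the bound cache is still a pair of tensors, so the [0]/[1] reads select the conv and SSM states rather than a pipeline stage. A sketch with assumed, purely illustrative shapes:

    import torch

    # Assumed layout for illustration: bind_kv_cache attaches a
    # (conv_state, ssm_state) pair; the shapes below are made up.
    conv_state = torch.zeros(4, 3, 128)         # stored as (..., width-1, dim)
    ssm_state = torch.zeros(4, 8, 16, 16)
    self_kv_cache = (conv_state, ssm_state)     # formerly self.kv_cache[0]

    conv = self_kv_cache[0].transpose(-1, -2)   # viewed as (..., dim, width-1)
    ssm = self_kv_cache[1]
    print(conv.shape, ssm.shape)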

View File

@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]

View File

@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, ShortConvAttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d

View File

@@ -365,7 +365,7 @@ class SparseAttnIndexer(CustomOp):
         return torch.ops.vllm.sparse_attn_indexer(
             hidden_states,
             self.k_cache.prefix,
-            self.k_cache.kv_cache[0],
+            self.k_cache.kv_cache,
             q_fp8,
             k,
             weights,
@@ -389,7 +389,7 @@ class SparseAttnIndexer(CustomOp):
         return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
             hidden_states,
             self.k_cache.prefix,
-            self.k_cache.kv_cache[0],
+            self.k_cache.kv_cache,
             q_fp8,
             k,
             weights,

View File

@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
         # Get KV cache and state indices
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[0][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             clear_linear_attention_cache_for_new_sequences(
                 kv_cache, state_indices_tensor, attn_metadata

View File

@@ -586,7 +586,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
     ):
         super().__init__()
-        self.kv_cache = [torch.tensor([])]
+        self.kv_cache = torch.tensor([])
         self.head_dim = head_dim
         self.prefix = prefix
         self.cache_config = cache_config

View File

@@ -51,7 +51,7 @@ def unified_kv_cache_update(
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[0]
kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
@@ -288,10 +288,7 @@ class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
         )
         # Placeholder KV cache (replaced by bind_kv_cache)
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        self.kv_cache = torch.tensor([])
         # Register in compilation context
         compilation_config = vllm_config.compilation_config

View File

@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens

View File

@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]

View File

@@ -858,7 +858,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
@@ -1046,7 +1046,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         Core attention computation with a packed non-spec decode fast path.
         """
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens