[V0 Deprecation] Refactor kv cache from list to element (#37487)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
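
The hunks below apply one mechanical change across attention and Mamba-style layers: the per-layer `kv_cache` attribute is no longer a list of placeholder tensors sized by `pipeline_parallel_size` but a single placeholder tensor, and every consumer drops the leading `[0]` index. A minimal sketch of that pattern, using a hypothetical stand-in class (only the names that appear in the hunks are taken from the diff; the surrounding class is simplified for illustration):

```python
import torch
import torch.nn as nn


class ToyAttentionLayer(nn.Module):
    """Hypothetical stand-in for Attention/MLAAttention to show the pattern."""

    def __init__(self, pipeline_parallel_size: int) -> None:
        super().__init__()
        # Before: one placeholder per pipeline stage, later replaced by bind_kv_cache
        # self.kv_cache = [torch.tensor([]) for _ in range(pipeline_parallel_size)]
        # After: a single placeholder tensor, still replaced by bind_kv_cache
        self.kv_cache = torch.tensor([])

    def forward(self) -> torch.Tensor:
        # Before: kv_cache = self.kv_cache[0]
        # After: the bound tensor is used directly
        kv_cache = self.kv_cache
        return kv_cache
```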
@@ -349,10 +349,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # use a placeholder kv cache tensor during init, which will be replaced
         # by bind_kv_cache
         # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        self.kv_cache = torch.tensor([])
 
         # Initialize KV cache quantization attributes
         _init_kv_cache_quant(self, quant_config, prefix)

@@ -599,7 +596,7 @@ def get_attention_context(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (
         f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "

@@ -415,12 +415,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        self.kv_cache = torch.tensor([])
 
         self.use_sparse = use_sparse
 

@@ -479,7 +474,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         attn_metadata = forward_context.attn_metadata
         if isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata[self.layer_name]
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         slot_mapping = forward_context.slot_mapping
 
         assert isinstance(slot_mapping, dict), (

@@ -939,7 +934,7 @@ def unified_mla_kv_cache_update(
         return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
 
     attn_layer = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
 
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (

@@ -168,7 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
             "sink_key and sink_value have not been prepared"
         )
         if not self.sink_populated:
-            self_kv_cache = self.kv_cache[0]
+            self_kv_cache = self.kv_cache
             torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
 
         return super().forward(query, key, value, output_shape)

@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         num_actual_tokens = attn_metadata.num_actual_tokens
-        constant_caches = self.kv_cache[0]
+        constant_caches = self.kv_cache
 
         q_proj_states = q_proj_states[:num_actual_tokens]
         k_proj_states = k_proj_states[:num_actual_tokens]

@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[0][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             clear_linear_attention_cache_for_new_sequences(
                 kv_cache, state_indices_tensor, attn_metadata
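
For the linear-attention layers (MiniMaxText01LinearAttention above, BailingMoELinearAttention further down), the bound cache itself keeps a leading dimension, so only the outer list index disappears: `self.kv_cache[0][0]` becomes `self.kv_cache[0]`. A hedged sketch of that access change, with an invented shape for the bound cache:

```python
import torch

# Hypothetical bound cache for a linear-attention layer; the first slice holds
# the recurrent KV state (shape is made up for illustration).
bound_cache = torch.zeros(1, 4, 8, 8)

# Before the refactor the layer held [bound_cache] (one entry per pipeline stage):
# kv_cache = self.kv_cache[0][0]   # unwrap the list, then take the first slice
# After the refactor the layer holds bound_cache directly:
kv_cache = bound_cache[0]          # only the inner slice index remains
```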
@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
         query_start_loc_p = attn_metadata.query_start_loc_p
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         has_initial_states_p = attn_metadata.has_initial_states_p

@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
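
In the Mamba-style layers (MambaMixer, MambaMixer2, and the ShortConv/GatedDeltaNet variants below), the `[0]`/`[1]` indices that survive are not the removed pipeline index: the single bound cache still packs the convolution and SSM states, as the hunks show. A sketch of that access pattern, with invented shapes:

```python
import torch

# Hypothetical packed cache: element 0 is the conv state (stored with the
# channel dim last, hence the transpose), element 1 is the SSM state.
self_kv_cache = (torch.zeros(2, 3, 16), torch.zeros(2, 8, 16))

conv_state = self_kv_cache[0].transpose(-1, -2)  # view as (..., dim, width-1)
ssm_state = self_kv_cache[1]
```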
@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, ShortConvAttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d

@@ -365,7 +365,7 @@ class SparseAttnIndexer(CustomOp):
         return torch.ops.vllm.sparse_attn_indexer(
             hidden_states,
             self.k_cache.prefix,
-            self.k_cache.kv_cache[0],
+            self.k_cache.kv_cache,
             q_fp8,
             k,
             weights,

@@ -389,7 +389,7 @@ class SparseAttnIndexer(CustomOp):
         return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
             hidden_states,
             self.k_cache.prefix,
-            self.k_cache.kv_cache[0],
+            self.k_cache.kv_cache,
             q_fp8,
             k,
             weights,

@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
 
         # Get KV cache and state indices
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[0][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             clear_linear_attention_cache_for_new_sequences(
                 kv_cache, state_indices_tensor, attn_metadata

@@ -586,7 +586,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
     ):
         super().__init__()
-        self.kv_cache = [torch.tensor([])]
+        self.kv_cache = torch.tensor([])
         self.head_dim = head_dim
         self.prefix = prefix
         self.cache_config = cache_config

@@ -51,7 +51,7 @@ def unified_kv_cache_update(
     """
     forward_context = get_forward_context()
     attn_layer = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
 
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (

@@ -288,10 +288,7 @@ class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
         )
 
         # Placeholder KV cache (replaced by bind_kv_cache)
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        self.kv_cache = torch.tensor([])
 
         # Register in compilation context
         compilation_config = vllm_config.compilation_config

@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens

@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]

@@ -858,7 +858,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens

@@ -1046,7 +1046,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         Core attention computation with a packed non-spec decode fast path.
         """
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens