diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py
index ffa01563e..5c2d03213 100644
--- a/tests/compile/passes/test_fusion_attn.py
+++ b/tests/compile/passes/test_fusion_attn.py
@@ -127,7 +127,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
         raw_tensor = raw_tensor.view(kv_cache_shape)
         kv_cache = raw_tensor.permute(*inv_order)
 
-        self.attn.kv_cache = [kv_cache]
+        self.attn.kv_cache = kv_cache
 
         # Build attn metadata
         self.attn_metadata = self.builder.build(
diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py
index 80dbdf914..eea21c917 100644
--- a/tests/compile/passes/test_rope_kvcache_fusion.py
+++ b/tests/compile/passes/test_rope_kvcache_fusion.py
@@ -148,7 +148,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
         raw_tensor = raw_tensor.view(kv_cache_shape)
         kv_cache = raw_tensor.permute(*inv_order)
 
-        self.attn.kv_cache = [kv_cache]
+        self.attn.kv_cache = kv_cache
 
         # Build attn metadata
         attn_metadata = self.builder.build(
@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
     }
     q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
     attn_layer = forward_context.no_compile_layers[model.layer_name]
-    kv_cache_unfused = attn_layer.kv_cache[0]
+    kv_cache_unfused = attn_layer.kv_cache
     del dummy
 
     torch._dynamo.mark_dynamic(qkv, 0)
@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
     }
     q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
     attn_layer = forward_context.no_compile_layers[model.layer_name]
-    kv_cache_fused = attn_layer.kv_cache[0]
+    kv_cache_fused = attn_layer.kv_cache
     del dummy
 
     assert fusion_pass.matched_count == 1
diff --git a/tests/v1/e2e/general/test_mamba_prefix_cache.py b/tests/v1/e2e/general/test_mamba_prefix_cache.py
index d69088772..747c5defe 100644
--- a/tests/v1/e2e/general/test_mamba_prefix_cache.py
+++ b/tests/v1/e2e/general/test_mamba_prefix_cache.py
@@ -258,8 +258,8 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
             mamba_kv_cache_dict[
                 num_computed_tokens - num_computed_tokens % BLOCK_SIZE
             ] = (
-                kv_cache[0][0][block_id].clone(),
-                kv_cache[0][1][block_id].clone(),
+                kv_cache[0][block_id].clone(),
+                kv_cache[1][block_id].clone(),
             )
 
             last_num_computed_tokens = num_computed_tokens
@@ -302,7 +302,7 @@ def get_fake_process_mamba_fn(
         mamba_layer_name = kv_cache_config.kv_cache_groups[
             mamba_group_id
         ].layer_names[0]
-        mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1]
+        mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[-1]
         mamba_block_table = input_batch.block_table.block_tables[
             mamba_group_id
         ].block_table.cpu[0]
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index dd23d9dfa..93c5435e8 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -670,8 +670,8 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
     runner.initialize_kv_cache(kv_cache_config)
 
-    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
-    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    layer_0_kv = vllm_ctx[layer_0].kv_cache
+    layer_1_kv = vllm_ctx[layer_1].kv_cache
 
     # check layer 1 kv cache does NOT share memory with layer 0
     assert id(layer_1_kv) != id(layer_0_kv)
@@ -740,8 +740,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
     runner.initialize_kv_cache(kv_cache_config)
     kv_cache_config_after_init = runner.kv_cache_config
 
-    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
-    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    layer_0_kv = vllm_ctx[layer_0].kv_cache
+    layer_1_kv = vllm_ctx[layer_1].kv_cache
 
     # check layer 1 kv cache shares memory with layer 0
     assert id(layer_1_kv) == id(layer_0_kv)
@@ -864,9 +864,9 @@ def test_hybrid_attention_mamba_tensor_shapes():
     np.random.shuffle(ind)
     blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
 
-    attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
-    conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
-    ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
+    attn_shape = vllm_ctx[layer_0].kv_cache.shape
+    conv_shape = vllm_ctx[layer_2].kv_cache[0].shape
+    ssm_shape = vllm_ctx[layer_2].kv_cache[1].shape
 
     # assert we are using FlashInfer
     assert attn_shape[0] % num_blocks == 0
@@ -905,21 +905,21 @@
     kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
 
     for layer in [layer_0, layer_1]:
-        # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
+        # attention: kv_cache[kernel_block_idx, kv_idx, ...]
         for i, kernel_block in enumerate(kernel_blocks_for_attention):
-            vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[kernel_block, :] = attn_blocks_constant[i]
 
     # fill mamba blocks with constants using kernel block indices
     for layer in [layer_2, layer_3, layer_4, layer_5]:
-        # mamba: kv_cache[0][component][kernel_block_idx, ...]
+        # mamba: kv_cache[component][kernel_block_idx, ...]
         for i, kv_block in enumerate(kv_blocks_for_mamba):
-            vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
-            vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[0][kv_block, :] = conv_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[1][kv_block, :] = ssm_blocks_constant[i]
 
     # verify attention and mamba contents are correct
     for layer in [layer_0, layer_1]:
         for i, kernel_block in enumerate(kernel_blocks_for_attention):
-            actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
+            actual_kv = vllm_ctx[layer].kv_cache[kernel_block, :]
             expected = attn_blocks_constant[i]
 
             # Check K and V separately
@@ -928,8 +928,8 @@
     for layer in [layer_2, layer_3, layer_4, layer_5]:
         for i, kv_block in enumerate(kv_blocks_for_mamba):
-            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
-            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
+            actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
+            actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
             expected_conv = conv_blocks_constant[i]
             expected_ssm = ssm_blocks_constant[i]
@@ -938,8 +938,8 @@
     for layer in [layer_2, layer_3, layer_4, layer_5]:
         for i, kv_block in enumerate(kv_blocks_for_mamba):
-            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
-            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
+            actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
+            actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
             expected_conv = conv_blocks_constant[i]
             expected_ssm = ssm_blocks_constant[i]
             assert torch.equal(actual_conv, expected_conv)
diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py
index 76f9a8f90..016b8aa5a 100644
--- a/tests/v1/worker/test_utils.py
+++ b/tests/v1/worker/test_utils.py
@@ -23,10 +23,10 @@ def test_bind_kv_cache(default_vllm_config):
     }
     runner_kv_caches: list[torch.Tensor] = []
     bind_kv_cache(kv_cache, ctx, runner_kv_caches)
-    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"]
-    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"]
-    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"]
-    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"]
+    assert ctx["layers.0.self_attn"].kv_cache is kv_cache["layers.0.self_attn"]
+    assert ctx["layers.1.self_attn"].kv_cache is kv_cache["layers.1.self_attn"]
+    assert ctx["layers.2.self_attn"].kv_cache is kv_cache["layers.2.self_attn"]
+    assert ctx["layers.3.self_attn"].kv_cache is kv_cache["layers.3.self_attn"]
 
     assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"]
     assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"]
@@ -50,8 +50,8 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
     runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
 
-    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"]
-    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"]
+    assert ctx["model.layers.20.attn"].kv_cache is kv_cache["model.layers.20.attn"]
+    assert ctx["model.layers.28.attn"].kv_cache is kv_cache["model.layers.28.attn"]
 
     assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
     assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
@@ -74,14 +74,14 @@ def test_bind_kv_cache_draft_model(default_vllm_config):
     runner_kv_caches: list[torch.Tensor] = []
     bind_kv_cache(kv_cache, ctx, runner_kv_caches)
 
-    assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
-    assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
+    assert ctx["model.layers.0.attn"].kv_cache is kv_cache["model.layers.0.attn"]
+    assert ctx["model.layers.1.attn"].kv_cache is kv_cache["model.layers.1.attn"]
     assert (
-        ctx["draft_model.layers.0.attn"].kv_cache[0]
+        ctx["draft_model.layers.0.attn"].kv_cache
         is kv_cache["draft_model.layers.0.attn"]
     )
     assert (
-        ctx["draft_model.layers.1.attn"].kv_cache[0]
+        ctx["draft_model.layers.1.attn"].kv_cache
         is kv_cache["draft_model.layers.1.attn"]
     )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
index 0c5db695b..24e156561 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
@@ -181,12 +181,10 @@ class ExampleConnector(KVConnectorBase_V1):
                 # Only process layers that have kv_cache
                 # attribute (attention layers) Skip non-attention
                 # layers like FusedMoE/MLP etc.
-                kv_cache_attr = getattr(layer, "kv_cache", None)
-                if kv_cache_attr is None:
+                kv_cache_layer = getattr(layer, "kv_cache", None)
+                if kv_cache_layer is None:
                     continue
 
-                kv_cache_layer = kv_cache_attr[0]
-
                 filename = self._generate_filename_debug(
                     layer_name, request.token_ids, request.mm_hashes
                 )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
index f18c3c4e4..35cd70606 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@@ -778,7 +778,7 @@ class LMCacheConnectorV1Impl:
                 continue
 
             if layer_name not in self.kv_caches:
-                self.kv_caches[layer_name] = attn_layer.kv_cache[0]
+                self.kv_caches[layer_name] = attn_layer.kv_cache
 
     ####################
     # Worker side APIs
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index 24e82610c..ce228b3c6 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 if kv_cache is None:
                     continue
 
-                layer = kv_cache[0]
+                layer = kv_cache
 
                 kv_cache = self.p2p_nccl_engine.recv_tensor(
                     request.request_id + "#" + layer_name, remote_address
diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py
index 9b5842594..cc143fad3 100644
--- a/vllm/model_executor/layers/attention/attention.py
+++ b/vllm/model_executor/layers/attention/attention.py
@@ -349,10 +349,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # use a placeholder kv cache tensor during init, which will be replaced
         # by bind_kv_cache
         # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        self.kv_cache = torch.tensor([])
 
         # Initialize KV cache quantization attributes
         _init_kv_cache_quant(self, quant_config, prefix)
@@ -599,7 +596,7 @@ def get_attention_context(
     if isinstance(attn_metadata, dict):
         attn_metadata = attn_metadata[layer_name]
     attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (
         f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index dcad30a8b..a51a8c2a7 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -415,12 +415,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
-        ]
+        self.kv_cache = torch.tensor([])
 
         self.use_sparse = use_sparse
@@ -479,7 +474,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         attn_metadata = forward_context.attn_metadata
         if isinstance(attn_metadata, dict):
             attn_metadata = attn_metadata[self.layer_name]
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         slot_mapping = forward_context.slot_mapping
 
         assert isinstance(slot_mapping, dict), (
@@ -939,7 +934,7 @@ def unified_mla_kv_cache_update(
         return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
 
     attn_layer = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
     slot_mapping = forward_context.slot_mapping
 
     assert isinstance(slot_mapping, dict), (
diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py
index 3b25a2357..913d73a16 100644
--- a/vllm/model_executor/layers/attention/static_sink_attention.py
+++ b/vllm/model_executor/layers/attention/static_sink_attention.py
@@ -168,7 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
             "sink_key and sink_value have not been prepared"
         )
         if not self.sink_populated:
-            self_kv_cache = self.kv_cache[0]
+            self_kv_cache = self.kv_cache
             torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
 
         return super().forward(query, key, value, output_shape)
diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py
index fddd807e0..46db5dc32 100644
--- a/vllm/model_executor/layers/kda.py
+++ b/vllm/model_executor/layers/kda.py
@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         num_actual_tokens = attn_metadata.num_actual_tokens
-        constant_caches = self.kv_cache[0]
+        constant_caches = self.kv_cache
 
         q_proj_states = q_proj_states[:num_actual_tokens]
         k_proj_states = k_proj_states[:num_actual_tokens]
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index f90309050..18fcc1426 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[0][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             clear_linear_attention_cache_for_new_sequences(
                 kv_cache, state_indices_tensor, attn_metadata
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 71baf2dae..82ca367fb 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
         query_start_loc_p = attn_metadata.query_start_loc_p
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         has_initial_states_p = attn_metadata.has_initial_states_p
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 232afefd5..9486e182e 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py
index fbdf0d537..d36dc0096 100644
--- a/vllm/model_executor/layers/mamba/short_conv.py
+++ b/vllm/model_executor/layers/mamba/short_conv.py
@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, ShortConvAttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         state_indices_tensor_p = attn_metadata.state_indices_tensor_p
         state_indices_tensor_d = attn_metadata.state_indices_tensor_d
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 0d55ba858..c34800247 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -365,7 +365,7 @@ class SparseAttnIndexer(CustomOp):
         return torch.ops.vllm.sparse_attn_indexer(
             hidden_states,
             self.k_cache.prefix,
-            self.k_cache.kv_cache[0],
+            self.k_cache.kv_cache,
             q_fp8,
             k,
             weights,
@@ -389,7 +389,7 @@
         return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
             hidden_states,
             self.k_cache.prefix,
-            self.k_cache.kv_cache[0],
+            self.k_cache.kv_cache,
             q_fp8,
             k,
             weights,
diff --git a/vllm/model_executor/models/bailing_moe_linear.py b/vllm/model_executor/models/bailing_moe_linear.py
index 8769e5197..ecc5d63ce 100644
--- a/vllm/model_executor/models/bailing_moe_linear.py
+++ b/vllm/model_executor/models/bailing_moe_linear.py
@@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase):
 
         # Get KV cache and state indices
         if attn_metadata is not None:
-            kv_cache = self.kv_cache[0][0]
+            kv_cache = self.kv_cache[0]
             state_indices_tensor = attn_metadata.state_indices_tensor
             clear_linear_attention_cache_for_new_sequences(
                 kv_cache, state_indices_tensor, attn_metadata
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index f31e9ac3e..f1c4a7b21 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -586,7 +586,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
     ):
         super().__init__()
-        self.kv_cache = [torch.tensor([])]
+        self.kv_cache = torch.tensor([])
         self.head_dim = head_dim
         self.prefix = prefix
         self.cache_config = cache_config
diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py
index bddaaadf5..d969441ac 100644
--- a/vllm/model_executor/models/extract_hidden_states.py
+++ b/vllm/model_executor/models/extract_hidden_states.py
@@ -51,7 +51,7 @@ def unified_kv_cache_update(
     """
     forward_context = get_forward_context()
     attn_layer = forward_context.no_compile_layers[layer_name]
-    kv_cache = attn_layer.kv_cache[0]
+    kv_cache = attn_layer.kv_cache
     slot_mapping = forward_context.slot_mapping
 
     assert isinstance(slot_mapping, dict), (
@@ -288,10 +288,7 @@ class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
         )
 
         # Placeholder KV cache (replaced by bind_kv_cache)
-        self.kv_cache = [
-            torch.tensor([])
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
+        self.kv_cache = torch.tensor([])
 
         # Register in compilation context
         compilation_config = vllm_config.compilation_config
diff --git a/vllm/model_executor/models/olmo_hybrid.py b/vllm/model_executor/models/olmo_hybrid.py
index bc932a51e..97e56b3ff 100644
--- a/vllm/model_executor/models/olmo_hybrid.py
+++ b/vllm/model_executor/models/olmo_hybrid.py
@@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 934ae8711..ffb86a8a9 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 5adccf178..5dfcd677b 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -858,7 +858,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
@@ -1046,7 +1046,7 @@
         Core attention computation with a packed non-spec decode fast path.
         """
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
-        self_kv_cache = self.kv_cache[0]
+        self_kv_cache = self.kv_cache
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 22f524cd9..6834918b8 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -481,13 +481,9 @@ class AiterFlashAttentionMetadataBuilder(
         ):
             layers = get_layers_from_vllm_config(self.vllm_config, Attention)
             first_layer_name = [k for k in layers][0]
-            kv_cache_shape = (
-                self.vllm_config.compilation_config.static_forward_context[
-                    first_layer_name
-                ]
-                .kv_cache[0]
-                .shape
-            )
+            kv_cache_shape = self.vllm_config.compilation_config.static_forward_context[
+                first_layer_name
+            ].kv_cache.shape
             num_blocks = kv_cache_shape[1]
             self.scale = torch.ones(
                 [num_blocks, self.num_heads_kv, self.block_size],
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 08fd27573..7c6248b37 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5830,7 +5830,10 @@ class GPUModelRunner(
         for layer in self.compilation_config.static_forward_context.values():
             if hasattr(layer, "kv_cache"):
-                layer.kv_cache = []
+                kv_cache = layer.kv_cache
+                layer.kv_cache = (
+                    torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else []
+                )
 
         gc.collect()
         torch.accelerator.empty_cache()
diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py
index ed618e099..c832389b1 100644
--- a/vllm/v1/worker/mamba_utils.py
+++ b/vllm/v1/worker/mamba_utils.py
@@ -119,7 +119,7 @@ def collect_mamba_copy_meta(
     layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
     for layer_name in layer_names:
         attention = forward_context[layer_name]
-        kv_caches: list[torch.Tensor] = attention.kv_cache[0]
+        kv_caches: list[torch.Tensor] = attention.kv_cache
         for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
             copy_spec = state_copy_func(
                 state, block_ids, src_block_idx, accept_token_bias + 1
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 63261ca9a..83fc12cb5 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -136,8 +136,8 @@ class KVBlockZeroer:
             for layer_name in group.layer_names:
                 if layer_name in runner_only_attn_layers:
                     continue
-                kv = static_forward_context[layer_name].kv_cache[0]
-                if isinstance(kv, list):
+                kv = static_forward_context[layer_name].kv_cache
+                if not isinstance(kv, torch.Tensor):
                     continue
                 dp = kv.data_ptr()
                 if dp in seen_ptrs:
@@ -510,8 +510,7 @@ def bind_kv_cache(
 
     # Bind kv_caches to forward context
     for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Keep list wrapper for layers that index kv_cache by engine slot.
-        forward_context[layer_name].kv_cache = [kv_cache]
+        forward_context[layer_name].kv_cache = kv_cache
 
 
 def is_residual_scattered_for_sp(