diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py index d9554f6fb..80dbdf914 100644 --- a/tests/compile/passes/test_rope_kvcache_fusion.py +++ b/tests/compile/passes/test_rope_kvcache_fusion.py @@ -295,7 +295,7 @@ def test_rope_kvcache_fusion( } q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused) attn_layer = forward_context.no_compile_layers[model.layer_name] - kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine] + kv_cache_unfused = attn_layer.kv_cache[0] del dummy torch._dynamo.mark_dynamic(qkv, 0) @@ -309,7 +309,7 @@ def test_rope_kvcache_fusion( } q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos) attn_layer = forward_context.no_compile_layers[model.layer_name] - kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine] + kv_cache_fused = attn_layer.kv_cache[0] del dummy assert fusion_pass.matched_count == 1 diff --git a/tests/v1/kv_connector/unit/test_decode_bench_connector.py b/tests/v1/kv_connector/unit/test_decode_bench_connector.py index 1d5343644..30652b3d5 100644 --- a/tests/v1/kv_connector/unit/test_decode_bench_connector.py +++ b/tests/v1/kv_connector/unit/test_decode_bench_connector.py @@ -86,7 +86,7 @@ class DecodeBenchTestRunner: self._block_hasher = get_request_block_hasher(block_size, sha256) self._dummy_ctx: ForwardContext = ForwardContext( - no_compile_layers={}, attn_metadata={}, virtual_engine=0, slot_mapping={} + no_compile_layers={}, attn_metadata={}, slot_mapping={} ) def new_request(self, token_ids: list[int]) -> Request: diff --git a/tests/v1/kv_connector/unit/test_lmcache_integration.py b/tests/v1/kv_connector/unit/test_lmcache_integration.py index 57ddaa8bf..5e08831a6 100644 --- a/tests/v1/kv_connector/unit/test_lmcache_integration.py +++ b/tests/v1/kv_connector/unit/test_lmcache_integration.py @@ -211,7 +211,6 @@ def test_forward_context_interface(): from vllm.forward_context import ForwardContext assumes(ForwardContext, "no_compile_layers", is_instance_of=dict) - assumes(ForwardContext, "virtual_engine") assumes(ForwardContext, "attn_metadata") diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 3da1b533a..674e09b4b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -599,7 +599,6 @@ class TestNixlHandshake: dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) _before_load = time.perf_counter() @@ -672,7 +671,6 @@ class TestNixlHandshake: dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) _before_load = time.perf_counter() @@ -908,7 +906,6 @@ class TestNixlHandshake: dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) _before_load = time.perf_counter() @@ -1079,7 +1076,6 @@ def test_kv_connector_stats(default_vllm_config, dist_init): dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) connector.start_load_kv(dummy_ctx) @@ -1890,7 +1886,6 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_ dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) connector.start_load_kv(dummy_ctx) @@ -2059,7 +2054,6 @@ def test_transfer_failure_logging( dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) @@ -2162,7 +2156,6 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) connector.start_load_kv(dummy_ctx) @@ -2215,7 +2208,6 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) dummy_ctx = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) connector.start_load_kv(dummy_ctx) diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index cf118f7f3..ba65f5bad 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -261,7 +261,6 @@ class RequestRunner: self._dummy_ctx: ForwardContext = ForwardContext( no_compile_layers={}, attn_metadata={}, - virtual_engine=0, slot_mapping={}, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py index 14feafced..0c5db695b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py @@ -185,7 +185,7 @@ class ExampleConnector(KVConnectorBase_V1): if kv_cache_attr is None: continue - kv_cache_layer = kv_cache_attr[forward_context.virtual_engine] + kv_cache_layer = kv_cache_attr[0] filename = self._generate_filename_debug( layer_name, request.token_ids, request.mm_hashes diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 4aacbddb8..f18c3c4e4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -778,9 +778,7 @@ class LMCacheConnectorV1Impl: continue if layer_name not in self.kv_caches: - self.kv_caches[layer_name] = attn_layer.kv_cache[ - forward_context.virtual_engine - ] + self.kv_caches[layer_name] = attn_layer.kv_cache[0] #################### # Worker side APIs diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 3be1be18e..24e82610c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -214,7 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1): if kv_cache is None: continue - layer = kv_cache[forward_context.virtual_engine] + layer = kv_cache[0] kv_cache = self.p2p_nccl_engine.recv_tensor( request.request_id + "#" + layer_name, remote_address diff --git a/vllm/forward_context.py b/vllm/forward_context.py index bf0f9da6e..a7aaeff4f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -197,8 +197,6 @@ class ForwardContext: for each microbatch. Set dynamically for each forward pass """ - # TODO: remove after making all virtual_engines share the same kv cache - virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: DPMetadata | None = None # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. @@ -265,7 +263,6 @@ def is_forward_context_available() -> bool: def create_forward_context( attn_metadata: Any, vllm_config: VllmConfig, - virtual_engine: int = 0, dp_metadata: DPMetadata | None = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, @@ -282,7 +279,6 @@ def create_forward_context( return ForwardContext( no_compile_layers=vllm_config.compilation_config.static_forward_context, all_moe_layers=all_moe_layers, - virtual_engine=virtual_engine, attn_metadata=attn_metadata, slot_mapping=slot_mapping or {}, dp_metadata=dp_metadata, @@ -313,7 +309,6 @@ def override_forward_context(forward_context: ForwardContext | None): def set_forward_context( attn_metadata: Any, vllm_config: VllmConfig, - virtual_engine: int = 0, num_tokens: int | None = None, num_tokens_across_dp: torch.Tensor | None = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, @@ -362,7 +357,6 @@ def set_forward_context( additional_kwargs = current_platform.set_additional_forward_context( attn_metadata=attn_metadata, vllm_config=vllm_config, - virtual_engine=virtual_engine, dp_metadata=dp_metadata, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, @@ -374,7 +368,6 @@ def set_forward_context( forward_context = create_forward_context( attn_metadata, vllm_config, - virtual_engine, dp_metadata, cudagraph_runtime_mode, batch_descriptor, diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 1ab22d408..5516cd329 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -589,7 +589,7 @@ def get_attention_context( - attn_metadata: Attention metadata for this specific layer, or None if no metadata available - attn_layer: The attention layer instance (Attention or MLAAttention) - - kv_cache: The KV cache tensor for current virtual engine + - kv_cache: The KV cache tensor for current forward pass - slot_mapping: The slot mapping for this specific layer Note: attn_metadata may be None, but attn_layer and kv_cache are always @@ -600,7 +600,7 @@ def get_attention_context( if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[layer_name] attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] - kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + kv_cache = attn_layer.kv_cache[0] slot_mapping = forward_context.slot_mapping assert isinstance(slot_mapping, dict), ( f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index b613f3ba9..9d2fa287d 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -480,7 +480,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] slot_mapping = forward_context.slot_mapping assert isinstance(slot_mapping, dict), ( @@ -940,7 +940,7 @@ def unified_mla_kv_cache_update( return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) attn_layer = forward_context.no_compile_layers[layer_name] - kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + kv_cache = attn_layer.kv_cache[0] slot_mapping = forward_context.slot_mapping assert isinstance(slot_mapping, dict), ( diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index 60419f967..3b25a2357 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -168,8 +168,7 @@ class StaticSinkAttention(Attention, CustomOp): "sink_key and sink_value have not been prepared" ) if not self.sink_populated: - forward_context: ForwardContext = get_forward_context() - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name) return super().forward(query, key, value, output_shape) diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index fde9ad36b..fddd807e0 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 num_actual_tokens = attn_metadata.num_actual_tokens - constant_caches = self.kv_cache[forward_context.virtual_engine] + constant_caches = self.kv_cache[0] q_proj_states = q_proj_states[:num_actual_tokens] k_proj_states = k_proj_states[:num_actual_tokens] diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 802141881..f90309050 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) if attn_metadata is not None: - kv_cache = self.kv_cache[forward_context.virtual_engine][0] + kv_cache = self.kv_cache[0][0] state_indices_tensor = attn_metadata.state_indices_tensor clear_linear_attention_cache_for_new_sequences( kv_cache, state_indices_tensor, attn_metadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6a33fc7d6..71baf2dae 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer): query_start_loc_p = attn_metadata.query_start_loc_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_d = attn_metadata.state_indices_tensor_d - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 971581d89..232afefd5 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] # conv_state = (..., dim, width-1) yet contiguous along 'dim' conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 2348af2d9..fbdf0d537 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, ShortConvAttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] conv_state = self_kv_cache[0].transpose(-1, -2) state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_d = attn_metadata.state_indices_tensor_d diff --git a/vllm/model_executor/models/bailing_moe_linear.py b/vllm/model_executor/models/bailing_moe_linear.py index 9b54ec634..8769e5197 100644 --- a/vllm/model_executor/models/bailing_moe_linear.py +++ b/vllm/model_executor/models/bailing_moe_linear.py @@ -709,7 +709,7 @@ class BailingMoELinearAttention(nn.Module, MambaBase): # Get KV cache and state indices if attn_metadata is not None: - kv_cache = self.kv_cache[forward_context.virtual_engine][0] + kv_cache = self.kv_cache[0][0] state_indices_tensor = attn_metadata.state_indices_tensor clear_linear_attention_cache_for_new_sequences( kv_cache, state_indices_tensor, attn_metadata diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py index ae9bdb5ed..bddaaadf5 100644 --- a/vllm/model_executor/models/extract_hidden_states.py +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -51,7 +51,7 @@ def unified_kv_cache_update( """ forward_context = get_forward_context() attn_layer = forward_context.no_compile_layers[layer_name] - kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + kv_cache = attn_layer.kv_cache[0] slot_mapping = forward_context.slot_mapping assert isinstance(slot_mapping, dict), ( diff --git a/vllm/model_executor/models/olmo_hybrid.py b/vllm/model_executor/models/olmo_hybrid.py index a94f8c875..bc932a51e 100644 --- a/vllm/model_executor/models/olmo_hybrid.py +++ b/vllm/model_executor/models/olmo_hybrid.py @@ -428,7 +428,7 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase): non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 81ba858d6..934ae8711 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -262,7 +262,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] # conv_state = (..., dim, width-1) yet contiguous along 'dim' conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 61c8a7ab1..10040bff0 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -842,7 +842,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): a=a, core_attn_out=core_attn_out, attn_metadata=attn_metadata, - virtual_engine=forward_context.virtual_engine, ) has_initial_state = attn_metadata.has_initial_state @@ -853,7 +852,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self_kv_cache = self.kv_cache[0] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens @@ -1036,13 +1035,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): a: torch.Tensor, core_attn_out: torch.Tensor, attn_metadata: GDNAttentionMetadata, - virtual_engine: int, ): """ Core attention computation with a packed non-spec decode fast path. """ non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache[virtual_engine] + self_kv_cache = self.kv_cache[0] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 2606aada0..63261ca9a 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -510,7 +510,7 @@ def bind_kv_cache( # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items(): - # NOTE: Use list because of v0 PP virtual engine. + # NOTE: Keep list wrapper for layers that index kv_cache by engine slot. forward_context[layer_name].kv_cache = [kv_cache]