diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6e2bb44e0..b6d918b41 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -13,6 +13,7 @@ from tests.v1.attention.utils import ( create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, + try_backend_includes_kv_cache_update, try_get_attention_backend, ) from vllm.config import ModelConfig @@ -295,6 +296,10 @@ def run_attention_backend( # Run forward pass # NOTE: The query, key, and value are already shaped correctly # in the calling test function. + if not try_backend_includes_kv_cache_update(actual_backend): + impl.do_kv_cache_update( + mock_layer, key, value, kv_cache, attn_metadata.slot_mapping + ) output = impl.forward( mock_layer, query, key, value, kv_cache, attn_metadata, output=output ) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index da4cea8fc..3cff52929 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -130,6 +130,18 @@ def try_get_attention_backend( raise AssertionError("unreachable") from None +def try_backend_includes_kv_cache_update( + backend: AttentionBackendEnum, +) -> bool: + """Return whether the backend's forward() includes the KV cache update, + skipping the test if the backend is not available.""" + try: + backend_class = backend.get_class() + return backend_class.forward_includes_kv_cache_update + except ImportError as e: + pytest.skip(f"{backend.name} not available: {e}") + raise AssertionError("unreachable") from None + + def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: """Create a FullAttentionSpec from ModelParams only.""" return FullAttentionSpec( diff --git a/tests/v1/kv_connector/unit/test_decode_bench_connector.py b/tests/v1/kv_connector/unit/test_decode_bench_connector.py index 24802317a..93f4f8537 100644 --- a/tests/v1/kv_connector/unit/test_decode_bench_connector.py +++ b/tests/v1/kv_connector/unit/test_decode_bench_connector.py @@ -86,7 +86,7 @@ class DecodeBenchTestRunner: self._block_hasher = get_request_block_hasher(block_size, sha256) self._dummy_ctx: ForwardContext = ForwardContext( - no_compile_layers={}, attn_metadata={}, virtual_engine=0 + no_compile_layers={}, attn_metadata={}, virtual_engine=0, slot_mapping={} ) def new_request(self, token_ids: list[int]) -> Request: diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 1286af75d..e93835598 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -548,6 +548,7 @@ class TestNixlHandshake: no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) @@ -618,6 +619,7 @@ class TestNixlHandshake: no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) @@ -844,6 +846,7 @@ class TestNixlHandshake: no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) @@ -1006,6 +1009,7 @@ def test_kv_connector_stats(default_vllm_config, dist_init): no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) connector.start_load_kv(dummy_ctx) @@ -1767,6 +1771,7 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_ no_compile_layers={}, 
attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) connector.start_load_kv(dummy_ctx) @@ -1917,6 +1922,7 @@ def test_transfer_failure_logging( no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) # Capture logs from the nixl_connector logger specifically @@ -2017,6 +2023,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) connector.start_load_kv(dummy_ctx) @@ -2067,6 +2074,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) no_compile_layers={}, attn_metadata={}, virtual_engine=0, + slot_mapping={}, ) connector.start_load_kv(dummy_ctx) diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 0c8a185a9..fea9ff09b 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -209,7 +209,10 @@ class RequestRunner: self._block_hasher = get_request_block_hasher(gpu_block_size, sha256) self._dummy_ctx: ForwardContext = ForwardContext( - no_compile_layers={}, attn_metadata={}, virtual_engine=0 + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + slot_mapping={}, ) def new_request(self, token_ids: list[int]): diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index b5ce37ea4..bd7005540 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -9,6 +9,7 @@ import torch from tests.v1.attention.utils import ( create_standard_kv_cache_spec, create_vllm_config, + try_backend_includes_kv_cache_update, try_get_attention_backend, ) from vllm.config import ParallelConfig, SpeculativeConfig @@ -120,6 +121,14 @@ def forward_attention( key = k.view(-1, num_kv_heads, dim_per_head) value = v.view(-1, num_kv_heads, dim_per_head) output = torch.empty_like(query) + if not try_backend_includes_kv_cache_update(backend): + instance.do_kv_cache_update( + layer=layer, + key=key, + value=value, + kv_cache=kv_cache, + slot_mapping=attn_metadata.slot_mapping, + ) return instance.forward( layer=layer, query=query, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a1d796d50..9a6945f7a 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -390,29 +390,43 @@ class Attention(nn.Module, AttentionLayerBase): if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size_v) if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward( - self, query, key, value, self_kv_cache, attn_metadata, output=output + kv_cache_dummy_dep = None + if not self.attn_backend.forward_includes_kv_cache_update: + kv_cache_dummy_dep = unified_kv_cache_update( + key, value, self.layer_name + ) + unified_attention_with_output( + query, + key, + value, + output, + self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) else: + kv_cache_dummy_dep = None + if not self.attn_backend.forward_includes_kv_cache_update and ( + # torch can only dispatch custom op if a tensor is passed + key is not None or value is not None + ): + kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( + key, value, self.layer_name + ) 
torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name + query, + key, + value, + output, + self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) return output.view(-1, hidden_size) else: + assert self.attn_backend.forward_includes_kv_cache_update, ( + "Split KV cache update not supported when output tensor not provided." + ) if self.use_direct_call: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward( - self, query, key, value, self_kv_cache, attn_metadata - ) + return unified_attention(query, key, value, self.layer_name) else: return torch.ops.vllm.unified_attention( query, key, value, self.layer_name @@ -802,6 +816,55 @@ direct_register_custom_op( ) +def unified_kv_cache_update( + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + """ + Returns a dummy that is passed to unified_attention to signal a side effect and + the data dependency between them to ensure torch.compile preserves ordering. + """ + forward_context = get_forward_context() + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + + slot_mapping = forward_context.slot_mapping + assert isinstance(slot_mapping, dict), ( + f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " + ) + layer_slot_mapping = slot_mapping.get(layer_name) + if layer_slot_mapping is not None: + assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( + f"{attn_layer.impl.__class__.__name__} does not support kv cache update" + ) + attn_layer.impl.do_kv_cache_update( + attn_layer, + key, + value, + kv_cache, + layer_slot_mapping, + ) + + return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) + + +def unified_kv_cache_update_fake( + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty(0, device=key.device, dtype=key.dtype) + + +direct_register_custom_op( + op_name="unified_kv_cache_update", + op_func=unified_kv_cache_update, + fake_impl=unified_kv_cache_update_fake, + mutates_args=[], +) + + @maybe_transfer_kv_layer def unified_attention_with_output( query: torch.Tensor, @@ -811,7 +874,12 @@ def unified_attention_with_output( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: + # kv_cache_dummy_dep is not used but accepting it creates a data dependency + # that ensures torch.compile preserves ordering between KV cache update and + # attention forward. 
+ del kv_cache_dummy_dep attn_metadata, self, kv_cache = get_attention_context(layer_name) self.impl.forward( @@ -835,6 +903,7 @@ def unified_attention_with_output_fake( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: return diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a856f6f31..301834d19 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -189,6 +189,7 @@ class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] + slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] """ Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata @@ -266,6 +267,7 @@ def create_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + slot_mapping: dict[str, torch.Tensor] | None = None, additional_kwargs: dict[str, Any] | None = None, skip_compiled: bool = False, ): @@ -282,6 +284,7 @@ def create_forward_context( remaining_moe_layers=remaining_moe_layers, virtual_engine=virtual_engine, attn_metadata=attn_metadata, + slot_mapping=slot_mapping or {}, dp_metadata=dp_metadata, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, @@ -316,6 +319,7 @@ def set_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, skip_compiled: bool = False, ): """A context manager that stores the current forward context, @@ -374,6 +378,7 @@ def set_forward_context( cudagraph_runtime_mode, batch_descriptor, ubatch_slices, + slot_mapping, additional_kwargs, skip_compiled, ) diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index a3f1f1072..f47fa1148 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -15,7 +15,7 @@ from vllm.v1.attention.backend import ( AttentionMetadata, AttentionType, CommonAttentionMetadata, - subclass_attention_backend, + subclass_attention_backend_with_overrides, ) from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec @@ -72,6 +72,7 @@ def create_cross_attention_backend( ) -> type[AttentionBackend]: prefix = "CrossAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() + underlying_impl = underlying_attn_backend.get_impl_cls() class CrossAttentionBuilder(underlying_builder): # type: ignore def build( @@ -106,18 +107,60 @@ def create_cross_attention_backend( ) # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here - new_metadata.slot_mapping = _get_cross_slot_mapping( + slot_mapping = _get_cross_slot_mapping( new_metadata.encoder_seq_lens_cpu, new_metadata.block_table_tensor, self.kv_cache_spec, self.device, ) - return super().build(common_prefix_len, new_metadata, fast_build) + attn_metadata = super().build(common_prefix_len, new_metadata, fast_build) + attn_metadata.slot_mapping = slot_mapping + return attn_metadata - attn_backend = 
subclass_attention_backend( + # NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by + # `CrossAttentionBuilder` instead of the one computed by `BlockTable` + # (gpu_model_runner) + class CrossAttentionImpl(underlying_impl): # type: ignore[valid-type,misc] + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + if ( + not underlying_attn_backend.forward_includes_kv_cache_update + and attn_metadata is not None + ): + self.do_kv_cache_update( + layer, key, value, kv_cache, attn_metadata.slot_mapping + ) + + return super().forward( + layer, + query, + key, + value, + kv_cache, + attn_metadata, + output, + output_scale, + output_block_scale, + ) + + attn_backend = subclass_attention_backend_with_overrides( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=CrossAttentionBuilder, + overrides={ + "get_builder_cls": lambda: CrossAttentionBuilder, + "get_impl_cls": lambda: CrossAttentionImpl, + "forward_includes_kv_cache_update": True, + }, ) return attn_backend diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 3f13572f6..19833292c 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -613,8 +613,9 @@ def weak_ref_tensor(tensor: Any) -> Any: Create a weak reference to a tensor. The new tensor will share the same data as the original tensor, but will not keep the original tensor alive. + This ignores 0-size tensors as those don't allocate any memory. """ - if isinstance(tensor, torch.Tensor): + if isinstance(tensor, torch.Tensor) and tensor.numel() > 0: return torch.ops._C.weak_ref_tensor(tensor) else: return tensor diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 297171c38..32a143f8e 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -53,6 +53,9 @@ class AttentionBackend(ABC): supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"] + # Does attention's forward() include kv cache update? + forward_includes_kv_cache_update: bool = True + @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [MultipleOf(1)] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 28e94bcd6..e88ee4de4 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -79,6 +79,8 @@ class FlashAttentionBackend(AttentionBackend): return [16, 32, 64] return [MultipleOf(16)] + forward_includes_kv_cache_update: bool = False + @staticmethod def get_name() -> str: return "FLASH_ATTN" @@ -652,32 +654,6 @@ class FlashAttentionImpl(AttentionImpl): # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - # key and value may be None in the case of cross attention. They are - # calculated once based on the output from the encoder and then cached - # in KV cache. - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( @@ -774,6 +750,49 @@ class FlashAttentionImpl(AttentionImpl): ) return output + def do_kv_cache_update( + self, + layer: torch.nn.Module, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return + + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is not None + or key is None + or value is None + ): + return + + key_cache, value_cache = kv_cache.unbind(0) + + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + def _forward_with_dcp( self, query: torch.Tensor, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d32c8d021..4dff3fe70 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -159,6 +159,10 @@ class SpecDecodeBaseProposer: with_numpy=True, ) + self._slot_mapping_buffer = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=device + ) + # Determine allowed attention backends once during initialization. self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): @@ -230,6 +234,24 @@ class SpecDecodeBaseProposer: positions = positions[0] self.positions[:num_tokens] = positions + def _get_slot_mapping( + self, + num_tokens: int, + slot_mapping: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return slot_mapping dict for EAGLE layers. + + If slot_mapping is provided, copies it into the buffer first. + """ + if slot_mapping is not None: + num_actual = slot_mapping.shape[0] + self._slot_mapping_buffer[:num_actual].copy_(slot_mapping) + if num_tokens > num_actual: + self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID) + + view = self._slot_mapping_buffer[:num_tokens] + return {name: view for name in self.attn_layer_names + self.indexer_layer_names} + def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: """Initialize cudagraph dispatcher keys for eagle. 
@@ -262,6 +284,9 @@ class SpecDecodeBaseProposer: sampling_metadata: SamplingMetadata, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, ) -> torch.Tensor: batch_size = common_attn_metadata.batch_size() @@ -358,6 +383,9 @@ class SpecDecodeBaseProposer: num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=self._get_slot_mapping( + num_input_tokens, common_attn_metadata.slot_mapping + ), ): ret_hidden_states = self.model(**model_kwargs) if not self.model_returns_tuple(): @@ -396,6 +424,7 @@ class SpecDecodeBaseProposer: positions=positions, hidden_states=hidden_states, common_attn_metadata=common_attn_metadata, + slot_mappings=slot_mappings, ) # [batch_size, num_tree_tokens] return torch.cat(draft_token_ids_list, dim=1) @@ -553,6 +582,9 @@ class SpecDecodeBaseProposer: num_tokens=input_batch_size, num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=self._get_slot_mapping( + input_batch_size, common_attn_metadata.slot_mapping + ), ): ret_hidden_states = self.model(**model_kwargs) if not self.model_returns_tuple(): @@ -760,6 +792,9 @@ class SpecDecodeBaseProposer: # [num_tokens, hidden_size] hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, ) -> list[torch.Tensor]: tree_attn_metadata_builder = self.runner.attn_groups[0][ 0 @@ -881,6 +916,9 @@ class SpecDecodeBaseProposer: self.vllm_config, num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=self._get_slot_mapping( + num_input_tokens, attn_metadata.slot_mapping + ), ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], @@ -1212,6 +1250,7 @@ class SpecDecodeBaseProposer: num_tokens: int, use_cudagraphs: bool = True, is_graph_capturing: bool = False, + slot_mappings: dict[str, torch.Tensor] | None = None, ) -> None: # FIXME: when using tree-based specdec, adjust number of forward-passes # according to the depth of the tree. @@ -1233,12 +1272,23 @@ class SpecDecodeBaseProposer: if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens + # Make sure to use EAGLE's own buffer during cudagraph capture. 
+ if ( + self.attn_layer_names + and slot_mappings is not None + and self.attn_layer_names[0] in slot_mappings + ): + slot_mapping_dict = self._get_slot_mapping(num_input_tokens) + else: + slot_mapping_dict = slot_mappings or {} + with set_forward_context( None, self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=slot_mapping_dict, ): if self.supports_mm_inputs: input_ids = None diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 989478f34..2e9330bf6 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -38,6 +38,9 @@ class MedusaProposer: self, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, # unused ) -> torch.Tensor: # Generate blocks and compute logits blocks = self.model(target_hidden_states) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index f97d54e63..53199d0ce 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -3,6 +3,7 @@ import os import numpy as np +import torch from numba import get_num_threads, jit, njit, prange, set_num_threads from vllm.config import VllmConfig @@ -132,6 +133,9 @@ class NgramProposer: sampled_token_ids: list[list[int]], num_tokens_no_spec: np.ndarray, token_ids_cpu: np.ndarray, + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, # unused ) -> list[list[int]]: # find which requests need ngram proposals valid_ngram_requests = [] diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 5d6dcc552..c5f8e6f86 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + from vllm.config import VllmConfig from vllm.v1.worker.gpu_input_batch import InputBatch @@ -33,6 +35,9 @@ class SuffixDecodingProposer: self, input_batch: InputBatch, sampled_token_ids: list[list[int]], + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, # unused ) -> list[list[int]]: """ Propose speculative tokens for each request in the input batch. 
Suffix Decoding diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index d568ccf1c..d5688158b 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -140,6 +140,18 @@ def init_kv_cache( return kv_caches +def build_slot_mappings_by_layer( + slot_mappings: torch.Tensor, + kv_cache_config: KVCacheConfig, +) -> dict[str, torch.Tensor]: + slot_mappings_by_layer: dict[str, torch.Tensor] = {} + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + slot_mapping = slot_mappings[i] + for layer_name in kv_cache_group.layer_names: + slot_mappings_by_layer[layer_name] = slot_mapping + return slot_mappings_by_layer + + def build_attn_metadata( attn_metadata_builders: list[AttentionMetadataBuilder], num_reqs: int, diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 4cb1ff3b9..3d626a2c5 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -14,7 +14,10 @@ from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.forward_context import set_forward_context from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.worker.gpu.attn_utils import build_attn_metadata +from vllm.v1.worker.gpu.attn_utils import ( + build_attn_metadata, + build_slot_mappings_by_layer, +) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.input_batch import InputBuffers @@ -88,7 +91,7 @@ class CudaGraphManager: positions = mrope_positions[:, :num_tokens] if inputs_embeds is not None: inputs_embeds = inputs_embeds[:num_tokens] - attn_metadata = prepare_inputs_to_capture( + attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, num_tokens, input_buffers, @@ -98,6 +101,9 @@ class CudaGraphManager: kv_cache_config, ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, kv_cache_config + ) # Warm up. 
with set_forward_context( @@ -106,6 +112,7 @@ class CudaGraphManager: num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, + slot_mapping=slot_mappings_by_layer, ): hidden_states = model( input_ids=input_ids, @@ -125,6 +132,7 @@ class CudaGraphManager: num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, + slot_mapping=slot_mappings_by_layer, ), torch.cuda.graph(graph, self.pool), ): @@ -244,7 +252,7 @@ def prepare_inputs_to_capture( attn_metadata_builders: list[AttentionMetadataBuilder], max_model_len: int, kv_cache_config: KVCacheConfig, -) -> dict[str, Any]: +) -> tuple[dict[str, Any], torch.Tensor]: num_tokens_per_req = num_tokens // num_reqs query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req @@ -274,4 +282,4 @@ def prepare_inputs_to_capture( slot_mappings=slot_mappings, kv_cache_config=kv_cache_config, ) - return attn_metadata + return attn_metadata, slot_mappings diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 8b100da28..75d4c4e00 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -24,6 +24,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.attn_utils import ( build_attn_metadata, + build_slot_mappings_by_layer, get_kv_cache_spec, init_attn_backend, init_kv_cache, @@ -881,6 +882,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: assert input_batch.mrope_positions is not None positions = input_batch.mrope_positions + slot_mappings = self.block_tables.compute_slot_mappings( + input_batch.idx_mapping, + input_batch.query_start_loc, + input_batch.positions[: input_batch.num_tokens], + ) + slot_mappings_by_layer = build_slot_mappings_by_layer( + slot_mappings, self.kv_cache_config + ) with set_forward_context( input_batch.attn_metadata, self.vllm_config, @@ -888,6 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): Support piecewise CUDA graph. 
cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, + slot_mapping=slot_mappings_by_layer, ): self.kv_connector.pre_forward(scheduler_output) hidden_states = self.model( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 863ed5db9..b46fc175d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -314,6 +314,7 @@ class ExecuteModelState(NamedTuple): aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None class GPUModelRunner( @@ -1595,6 +1596,7 @@ class GPUModelRunner( for_cudagraph_capture: bool = False, num_scheduled_tokens: dict[str, int] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None, + slot_mappings: dict[int, torch.Tensor] | None = None, ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: """ :return: tuple[attn_metadata, spec_decode_common_attn_metadata] @@ -1628,7 +1630,7 @@ class GPUModelRunner( kv_cache_groups = self.kv_cache_config.kv_cache_groups - def _get_block_table_and_slot_mapping(kv_cache_gid: int): + def _get_block_table(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): @@ -1637,24 +1639,19 @@ class GPUModelRunner( dtype=torch.int32, device=self.device, ) - slot_mapping = torch.zeros( - (num_tokens_padded,), - dtype=torch.int64, - device=self.device, - ) else: blk_table = self.input_batch.block_table[kv_cache_gid] blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) - slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID - slot_mapping[num_tokens:num_tokens_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + return blk_table_tensor - return blk_table_tensor, slot_mapping + assert slot_mappings is not None + block_table_gid_0 = _get_block_table(0) + slot_mapping_gid_0 = slot_mappings[0] - block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) if self.model_config.enable_return_routed_experts: self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() cm_base = CommonAttentionMetadata( @@ -1779,9 +1776,8 @@ class GPUModelRunner( num_reqs_padded, ) if kv_cache_gid > 0: - cm.block_table_tensor, cm.slot_mapping = ( - _get_block_table_and_slot_mapping(kv_cache_gid) - ) + cm.block_table_tensor = _get_block_table(kv_cache_gid) + cm.slot_mapping = slot_mappings[kv_cache_gid] if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): @@ -3119,6 +3115,80 @@ class GPUModelRunner( pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True + def _get_slot_mappings( + self, + num_tokens_padded: int, + num_reqs_padded: int, + num_tokens_unpadded: int, + ubatch_slices: "UBatchSlices | None" = None, + ) -> tuple[ + dict[int, torch.Tensor] | None, + dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + ]: + """ + Build slot mappings in both formats needed by the system. 
+ + Args: + num_tokens_padded: Total number of tokens (padded) + num_reqs_padded: Total number of requests (padded) + num_tokens_unpadded: Actual number of tokens (unpadded) + ubatch_slices: Optional ubatch slicing info for DBO + + Returns: + A tuple of: + - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata + - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext + """ + if not ( + hasattr(self, "kv_cache_config") + and self.kv_cache_config is not None + and len(self.kv_cache_config.kv_cache_groups) > 0 + ): + return None, None + + def _get_slot_mapping(kv_cache_gid: int): + assert num_reqs_padded is not None and num_tokens_padded is not None + kv_cache_spec = self.kv_cache_config.kv_cache_groups[ + kv_cache_gid + ].kv_cache_spec + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + slot_mapping = torch.zeros( + (num_tokens_padded,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_gid] + slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1) + + return slot_mapping + + slot_mappings_by_gid = { + gid: _get_slot_mapping(gid) + for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups) + } + + slot_mappings_by_layer: dict[str, torch.Tensor] = {} + for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups): + slot_mapping = slot_mappings_by_gid[gid] + for layer_name in kv_cache_group.layer_names: + slot_mappings_by_layer[layer_name] = slot_mapping + + if ubatch_slices is not None: + result: list[dict[str, torch.Tensor]] = [] + for ubatch in ubatch_slices: + sliced_mappings: dict[str, torch.Tensor] = {} + for layer_name, slot_mapping in slot_mappings_by_layer.items(): + sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice] + result.append(sliced_mappings) + return slot_mappings_by_gid, result + + return slot_mappings_by_gid, slot_mappings_by_layer + @torch.inference_mode() def execute_model( self, @@ -3248,6 +3318,17 @@ class GPUModelRunner( ubatch_slices_padded, ) + # True if any attention backend handles KV cache update separately + # from forward() (i.e., forward_includes_kv_cache_update=False). When true, + # slot_mappings must use padded dimensions to match the key/value tensors. 
+ has_separate_kv_update = not all( + all( + g.backend.forward_includes_kv_cache_update + for g in self.attn_groups[id] + ) + for id, spec in enumerate(self.kv_cache_config.kv_cache_groups) + if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec) + ) pad_attn = cudagraph_mode == CUDAGraphMode.FULL if self.cache_config.mamba_cache_mode == "align": @@ -3265,6 +3346,17 @@ class GPUModelRunner( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded + if pad_attn or has_separate_kv_update + else num_tokens_unpadded, + num_reqs_padded=( + num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs + ), + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + attn_metadata, spec_decode_common_attn_metadata = ( self._build_attention_metadata( num_tokens=num_tokens_unpadded, @@ -3277,6 +3369,7 @@ class GPUModelRunner( use_spec_decode=use_spec_decode, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) ) @@ -3317,6 +3410,7 @@ class GPUModelRunner( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), @@ -3399,6 +3493,7 @@ class GPUModelRunner( aux_hidden_states, ec_connector_output, cudagraph_stats, + slot_mappings, ) self.kv_connector_output = kv_connector_output return None @@ -3435,6 +3530,7 @@ class GPUModelRunner( aux_hidden_states, ec_connector_output, cudagraph_stats, + slot_mappings, ) = self.execute_model_state # Clear ephemeral state. 
self.execute_model_state = None @@ -3468,6 +3564,7 @@ class GPUModelRunner( aux_hidden_states, spec_decode_metadata, spec_decode_common_attn_metadata, + slot_mappings, ) self._copy_draft_token_ids_to_cpu(scheduler_output) @@ -3676,6 +3773,7 @@ class GPUModelRunner( aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -3687,11 +3785,14 @@ class GPUModelRunner( sampled_token_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, + slot_mappings=slot_mappings, ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) - draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) + draft_token_ids = self.drafter.propose( + self.input_batch, sampled_token_ids, slot_mappings=slot_mappings + ) elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -3716,6 +3817,7 @@ class GPUModelRunner( draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, + slot_mappings=slot_mappings, ) elif spec_config.use_eagle() or spec_config.uses_draft_model(): assert isinstance(self.drafter, EagleProposer | DraftModelProposer) @@ -3826,6 +3928,7 @@ class GPUModelRunner( common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, + slot_mappings=slot_mappings, ) return draft_token_ids @@ -4406,6 +4509,13 @@ class GPUModelRunner( attn_metadata: PerLayerAttnMetadata | None = None + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. 
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: @@ -4431,6 +4541,7 @@ class GPUModelRunner( max_query_len=max_query_len, ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, ) with self.maybe_dummy_run_with_lora( @@ -4499,6 +4610,7 @@ class GPUModelRunner( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), ): outputs = self.model( @@ -4545,6 +4657,7 @@ class GPUModelRunner( num_tokens, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, + slot_mappings=slot_mappings, ) # We register layerwise NVTX hooks here after the first dynamo tracing is diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index af09129e6..97f8b92ce 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -295,6 +295,7 @@ class UBatchWrapper: self, ubatch_slices, attn_metadata, + slot_mapping, input_ids, positions, inputs_embeds, @@ -306,6 +307,9 @@ class UBatchWrapper: ) -> list[UbatchMetadata]: # Create one forward context per ubatch forward_contexts = [] + # slot_mapping can be None, an empty dict (from create_forward_context + # converting None to {}), or a list of dicts (one per ubatch) + has_slot_mapping = slot_mapping and isinstance(slot_mapping, list) for i, ubatch_slice in enumerate(ubatch_slices): forward_contexts.append( create_forward_context( @@ -314,6 +318,7 @@ class UBatchWrapper: dp_metadata=dp_metadata[i], batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=slot_mapping[i] if has_slot_mapping else None, ) ) @@ -406,6 +411,7 @@ class UBatchWrapper: return self.cudagraph_wrapper(*args, **kwargs) attn_metadata = forward_context.attn_metadata + slot_mapping = forward_context.slot_mapping num_tokens = ( ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start ) * 2 @@ -440,6 +446,7 @@ class UBatchWrapper: ubatch_metadata = self._make_ubatch_metadata( ubatch_slices=ubatch_slices, attn_metadata=attn_metadata, + slot_mapping=slot_mapping, input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -462,6 +469,7 @@ class UBatchWrapper: ubatch_metadata = self._make_ubatch_metadata( ubatch_slices=ubatch_slices, attn_metadata=attn_metadata, + slot_mapping=slot_mapping, input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors,
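
The calling convention introduced by `forward_includes_kv_cache_update` can be summarized with a small self-contained sketch. The classes and helpers below are illustrative only, not vLLM code: they model how a caller drives a backend whose forward() no longer writes the KV cache itself, which is the same branch the updated tests take via try_backend_includes_kv_cache_update and do_kv_cache_update.

import torch


class ToyBackend:
    # When False, forward() assumes K/V are already in the cache and the caller
    # must invoke do_kv_cache_update() first; when True, forward() performs the
    # cache write itself (the previous behavior).
    forward_includes_kv_cache_update: bool = False

    def do_kv_cache_update(
        self,
        key: torch.Tensor,           # [num_tokens, head_dim]
        value: torch.Tensor,         # [num_tokens, head_dim]
        kv_cache: torch.Tensor,      # [2, num_slots, head_dim]
        slot_mapping: torch.Tensor,  # [num_tokens]; -1 marks padded tokens
    ) -> None:
        valid = slot_mapping >= 0
        kv_cache[0, slot_mapping[valid]] = key[valid]
        kv_cache[1, slot_mapping[valid]] = value[valid]

    def forward(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> torch.Tensor:
        # Real attention math elided; reading the cache back demonstrates that
        # the update must already have happened.
        return kv_cache[0, slot_mapping.clamp(min=0)]


def run_attention(backend, query, key, value, kv_cache, slot_mapping):
    # Only perform the explicit cache update when the backend's forward()
    # does not include it.
    if not backend.forward_includes_kv_cache_update:
        backend.do_kv_cache_update(key, value, kv_cache, slot_mapping)
    return backend.forward(query, kv_cache, slot_mapping)


kv_cache = torch.zeros(2, 32, 8)
slot_mapping = torch.tensor([3, 4, 5, -1])  # last position is padding
out = run_attention(
    ToyBackend(),
    query=torch.randn(4, 8),
    key=torch.randn(4, 8),
    value=torch.randn(4, 8),
    kv_cache=kv_cache,
    slot_mapping=slot_mapping,
)
assert out.shape == (4, 8)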
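
The kv_cache_dummy_dep tensor threaded into unified_attention_with_output exists only to make the cache-write side effect visible as a data dependency, so that tracing cannot reorder or eliminate it relative to the attention op. A minimal sketch of that pattern follows, assuming plain functions in place of the registered custom ops (names and signatures are illustrative, not vLLM APIs).

import torch


def kv_cache_update(key: torch.Tensor, value: torch.Tensor,
                    kv_cache: torch.Tensor, slots: torch.Tensor) -> torch.Tensor:
    # Side effect: write K/V into the cache at the given slots.
    kv_cache[0, slots] = key
    kv_cache[1, slots] = value
    # Return a 0-size dummy tensor; passing it to the attention op turns the
    # side effect into an explicit producer/consumer edge in the graph.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)


def attention(query: torch.Tensor, kv_cache: torch.Tensor,
              kv_cache_dummy_dep: torch.Tensor | None = None) -> torch.Tensor:
    # The dummy is unused; accepting it keeps the cache update ordered before
    # this op once both are traced as custom ops.
    del kv_cache_dummy_dep
    # Attention math elided; read from the cache to show the ordering matters.
    return kv_cache[0, : query.shape[0]]


kv_cache = torch.zeros(2, 16, 8)
dep = kv_cache_update(torch.randn(4, 8), torch.randn(4, 8),
                      kv_cache, torch.arange(4))
out = attention(torch.randn(4, 8), kv_cache, kv_cache_dummy_dep=dep)
assert out.shape == (4, 8)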