From dc0428ebb879af3224e5d9ef5ab93cb04d8d1b27 Mon Sep 17 00:00:00 2001
From: yzong-rh
Date: Wed, 1 Apr 2026 11:23:15 -0400
Subject: [PATCH] [NIXL][BUG] Fix Triton heterogeneous TP (#37940)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Yifan
Co-authored-by: Nicolò Lucchesi
---
 .../config_sweep_accuracy_test.sh              |  7 ++++
 .../kv_connector/unit/test_nixl_connector.py   | 32 ++++++++++---------
 .../kv_connector/v1/nixl_connector.py          | 16 ++++++++++
 vllm/v1/attention/backends/triton_attn.py      | 17 +++++++---
 .../ops/triton_reshape_and_cache_flash.py      | 12 +++++--
 5 files changed, 62 insertions(+), 22 deletions(-)

diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
index 92ab254dd..fe79a99fc 100755
--- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
@@ -23,6 +23,10 @@ hybrid_ssm_configs=(
     # TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
     "ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
 )
+sw_attn_configs=(
+    "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
+    "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
+)
 
 # Select config array based on DP_EP env var
 if [[ -n "${DP_EP:-}" ]]; then
@@ -31,6 +35,9 @@ if [[ -n "${DP_EP:-}" ]]; then
 elif [[ -n "${HYBRID_SSM:-}" ]]; then
     configs=("${hybrid_ssm_configs[@]}")
     echo "HYBRID_SSM is set, using hybrid_ssm_configs."
+elif [[ -n "${SW_ATTN:-}" ]]; then
+    configs=("${sw_attn_configs[@]}")
+    echo "SW_ATTN is set, using sw_attn_configs."
 else
     configs=("${tp_configs[@]}")
 fi
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index b4ee97cd1..17a70a3a5 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1586,10 +1586,18 @@ def test_register_kv_caches(
     expected_base_addrs: list[int]
     expected_num_entries: int
     kv_caches: dict[str, torch.Tensor]
-    assert str(enable_cross_layers).lower() != "true" or (
-        (attn_backend not in ("FLASH_ATTN", "FLASHINFER"))
-        or connector.prefer_cross_layer_blocks
+    if str(enable_cross_layers).lower() == "true":
+        assert connector.prefer_cross_layer_blocks == (
+            attn_backend in ("FLASH_ATTN", "FLASHINFER", "TRITON_ATTN")
+        )
+    else:
+        assert not connector.prefer_cross_layer_blocks
+
+    test_shape = backend_cls.get_kv_cache_shape(
+        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
     )
+    is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
+
     if connector.prefer_cross_layer_blocks:
         with set_current_vllm_config(vllm_config):
             _, cross_layers_kv_cache, _ = (
@@ -1619,7 +1627,7 @@
         ]
 
         expected_num_entries = 1
-        expected_blocks_count = 8
+        expected_blocks_count = num_blocks * (2 if is_blocks_first else 1)
 
         kv_caches = {"all-layers": cross_layers_kv_cache}
     else:
@@ -1639,12 +1647,6 @@
         }
 
     # Store tensor info for validation
-
-    test_shape = backend_cls.get_kv_cache_shape(
-        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
-    )
-    is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
-
     if is_blocks_first:
         expected_tensor_size = (
             shared_tensor.element_size() * shared_tensor.numel()
@@ -1696,13 +1698,13 @@
 
     if connector.prefer_cross_layer_blocks:
         num_blocks = 8
-        expected_block_len = expected_tensor_size // num_blocks
     else:
         num_blocks = kv_cache_config.num_blocks
-        if is_blocks_first:
-            expected_block_len = expected_tensor_size // num_blocks // 2
-        else:
-            expected_block_len = expected_tensor_size // num_blocks
+
+    if is_blocks_first:
+        expected_block_len = expected_tensor_size // num_blocks // 2
+    else:
+        expected_block_len = expected_tensor_size // num_blocks
 
     for i, block_entry in enumerate(blocks_data):
         block_start_addr, block_len, tp_rank = block_entry
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a86a52a6a..0aaf3b6e9 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -330,6 +330,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
         if backend.get_name() not in (
             "FLASH_ATTN",
             "FLASHINFER",
+            "TRITON_ATTN",
         ):
             return False
 
@@ -2118,6 +2119,21 @@ class NixlConnectorWorker:
                 "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
             )
 
+        # Heterogeneous TP requires head-splitting, which only works with
+        # HND layout. MLA and replicated-KV cases don't split on heads.
+        # Mamba doesn't support heterogeneous TP.
+        if (
+            abs(tp_ratio) != 1
+            and not self.use_mla
+            and not self.kv_topo.is_kv_replicated(remote_engine_id)
+            and kv_cache_layout != "HND"
+            and not self.enable_permute_local_kv
+        ):
+            raise RuntimeError(
+                "Heterogeneous TP head-dimension splitting requires contiguous heads. "
+                "Use HND layout on the prefill side."
+            )
+
         # Block len can only vary across layers when using MLA.
         remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index f9a688f65..3dd081745 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -29,6 +29,7 @@ from vllm.v1.attention.backend import (
     CommonAttentionMetadata,
     MultipleOf,
 )
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
 from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
     triton_reshape_and_cache_flash,
@@ -309,12 +310,20 @@
     ) -> tuple[int, ...]:
         # `stride_order` indicates the permutation that gets
         # us from `get_kv_cache_shape` to the actual memory layout we want.
-        if include_num_layers_dimension:
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD" and include_num_layers_dimension:
             # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
             return (1, 0, 2, 3, 4, 5)
-
-        # (num_blocks, 2, block_size, num_kv_heads, head_size)
-        return (0, 1, 2, 3, 4)
+        elif cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND" and include_num_layers_dimension:
+            # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
+            return (1, 2, 4, 0, 3, 5)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout: {cache_layout}")
+        return stride_order
 
     @staticmethod
     def use_cascade_attention(*args, **kwargs) -> bool:
diff --git a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
index f98b50f8f..eeec60962 100644
--- a/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
@@ -70,9 +70,15 @@ def reshape_and_cache_kernel_flash(
             + (cur_dim % x)
         )
     else:
-        tgt_base = block_idx * block_stride + block_offset * page_stride
-        tgt_idx_k = tgt_base + tile_pos
-        tgt_idx_v = tgt_base + tile_pos
+        cur_head = tile_pos // head_size
+        cur_dim = tile_pos % head_size
+        tgt_idx_k = (
+            block_idx * block_stride
+            + block_offset * page_stride
+            + cur_head * head_stride
+            + cur_dim
+        )
+        tgt_idx_v = tgt_idx_k
 
     # [TILE_SIZE]
     key_load = tl.load(