[NIXL][BUG] Fix Triton heterogeneous TP (#37940)

Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Author: yzong-rh
Date: 2026-04-01 11:23:15 -04:00
Committed by: GitHub
Parent: 148c2072ec
Commit: dc0428ebb8
5 changed files with 62 additions and 22 deletions

View File

@@ -23,6 +23,10 @@ hybrid_ssm_configs=(
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
+sw_attn_configs=(
+"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
+"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
+)
# Select config array based on DP_EP env var
if [[ -n "${DP_EP:-}" ]]; then
@@ -31,6 +35,9 @@ if [[ -n "${DP_EP:-}" ]]; then
elif [[ -n "${HYBRID_SSM:-}" ]]; then
configs=("${hybrid_ssm_configs[@]}")
echo "HYBRID_SSM is set, using hybrid_ssm_configs."
+elif [[ -n "${SW_ATTN:-}" ]]; then
+configs=("${sw_attn_configs[@]}")
+echo "SW_ATTN is set, using sw_attn_configs."
else
configs=("${tp_configs[@]}")
fi
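
Note (illustration, not part of this commit): each entry in these config arrays is a space-separated list of KEY=VALUE pairs, with commas standing in for separators inside the extra `vllm serve` arguments. The harness that consumes these strings is not shown in this diff, so the parsing below is only an assumed sketch of the entry format; the two new entries exercise heterogeneous TP in both directions (prefill TP 1 → decode TP 2 and prefill TP 2 → decode TP 1) for a sliding-window-attention model.

```python
# Hedged sketch: decompose one sw_attn_configs entry into its settings and the
# comma-separated extra CLI args. The real test harness may parse these
# differently; this only illustrates the entry format.
entry = (
    "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it "
    "PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
)
settings = dict(pair.split("=", 1) for pair in entry.split())
extra_args = settings["VLLM_SERVE_EXTRA_ARGS"].split(",")

assert settings["PREFILLER_TP_SIZE"] == "1" and settings["DECODER_TP_SIZE"] == "2"
assert extra_args == ["--max-model-len", "8192"]
```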

View File

@@ -1586,10 +1586,18 @@ def test_register_kv_caches(
expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
-assert str(enable_cross_layers).lower() != "true" or (
-(attn_backend not in ("FLASH_ATTN", "FLASHINFER"))
-or connector.prefer_cross_layer_blocks
-)
+if str(enable_cross_layers).lower() == "true":
+assert connector.prefer_cross_layer_blocks == (
+attn_backend in ("FLASH_ATTN", "FLASHINFER", "TRITON_ATTN")
+)
+else:
+assert not connector.prefer_cross_layer_blocks
+test_shape = backend_cls.get_kv_cache_shape(
+num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+)
+is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if connector.prefer_cross_layer_blocks:
with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
@@ -1619,7 +1627,7 @@ def test_register_kv_caches(
]
expected_num_entries = 1
-expected_blocks_count = 8
+expected_blocks_count = num_blocks * (2 if is_blocks_first else 1)
kv_caches = {"all-layers": cross_layers_kv_cache}
else:
@@ -1639,12 +1647,6 @@ def test_register_kv_caches(
}
# Store tensor info for validation
-test_shape = backend_cls.get_kv_cache_shape(
-num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
-)
-is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = (
shared_tensor.element_size() * shared_tensor.numel()
@@ -1696,9 +1698,9 @@ def test_register_kv_caches(
if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
else:
num_blocks = kv_cache_config.num_blocks
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
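
Note (illustration, not part of this commit): the test now derives the expected block count from the cache layout instead of hardcoding 8. A minimal sketch of the probe used above; the helper name is illustrative, only the `get_kv_cache_shape` call comes from the diff.

```python
# Hedged sketch: probe the backend with a single block and check whether the
# block count lands in the leading dimension of the 5-D KV-cache shape.
# When it does ("blocks first"), the test above doubles the expected count:
# expected_blocks_count = num_blocks * (2 if is_blocks_first else 1).
def is_blocks_first_layout(backend_cls) -> bool:
    test_shape = backend_cls.get_kv_cache_shape(
        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
    )
    return len(test_shape) == 5 and test_shape[0] == 1
```

The factor of 2 reflects the K and V halves of each block being tracked as separate entries in the blocks-first layout (an inference from the assertion, not spelled out in the diff).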

View File

@@ -330,6 +330,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
if backend.get_name() not in (
"FLASH_ATTN",
"FLASHINFER",
"TRITON_ATTN",
):
return False
@@ -2118,6 +2119,21 @@ class NixlConnectorWorker:
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
+# Heterogeneous TP requires head-splitting, which only works with
+# HND layout. MLA and replicated-KV cases don't split on heads.
+# Mamba doesn't support heterogeneous TP.
+if (
+abs(tp_ratio) != 1
+and not self.use_mla
+and not self.kv_topo.is_kv_replicated(remote_engine_id)
+and kv_cache_layout != "HND"
+and not self.enable_permute_local_kv
+):
+raise RuntimeError(
+"Heterogeneous TP head-dimension splitting requires contiguous heads. "
+"Use HND layout on the prefill side."
+)
# Block len can only vary across layers when using MLA.
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
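
Note (illustration, not part of this commit): the new guard rejects heterogeneous TP with a non-HND cache unless one of the exemptions applies. A stand-alone restatement of that predicate, rewritten as a free function with flattened arguments for illustration:

```python
# Hedged sketch mirroring the guard added above: heterogeneous TP splits the
# remote KV cache along the head dimension, and that split is only contiguous
# when the producing side stores the cache in HND layout. MLA and
# replicated-KV deployments never split on heads, and enable_permute_local_kv
# provides an alternative path, so all three are exempt.
def rejects_layout(
    tp_ratio: int,
    use_mla: bool,
    kv_replicated: bool,
    kv_cache_layout: str,
    enable_permute_local_kv: bool,
) -> bool:
    return (
        abs(tp_ratio) != 1
        and not use_mla
        and not kv_replicated
        and kv_cache_layout != "HND"
        and not enable_permute_local_kv
    )

assert rejects_layout(2, False, False, "NHD", False)       # hetero TP + NHD -> error
assert not rejects_layout(1, False, False, "NHD", False)   # homogeneous TP is fine
assert not rejects_layout(2, True, False, "NHD", False)    # MLA never splits on heads
```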

View File

@@ -29,6 +29,7 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata,
MultipleOf,
)
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
@@ -309,12 +310,20 @@ class TritonAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
-if include_num_layers_dimension:
-# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
-return (1, 0, 2, 3, 4, 5)
-# (num_blocks, 2, block_size, num_kv_heads, head_size)
-return (0, 1, 2, 3, 4)
+cache_layout = get_kv_cache_layout()
+if cache_layout == "NHD" and include_num_layers_dimension:
+# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
+stride_order = (1, 0, 2, 3, 4, 5)
+elif cache_layout == "NHD":
+# (num_blocks, 2, block_size, num_kv_heads, head_size)
+stride_order = (0, 1, 2, 3, 4)
+elif cache_layout == "HND" and include_num_layers_dimension:
+# (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
+stride_order = (1, 2, 4, 0, 3, 5)
+elif cache_layout == "HND":
+stride_order = (0, 1, 3, 2, 4)
+else:
+raise ValueError(f"Unknown cache layout: {cache_layout}")
+return stride_order
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
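
Note (illustration, not part of this commit): per the method's own comment, the stride order is the permutation that maps the logical `get_kv_cache_shape` dimensions onto the physical memory layout. A small sketch of what the two 5-D orders above mean; the symbolic shape names are only for illustration.

```python
# Hedged sketch: applying a stride order to the logical 5-D shape
# (num_blocks, 2, block_size, num_kv_heads, head_size) yields the physical
# dimension order. The HND order added above moves num_kv_heads ahead of
# block_size, which is what makes per-head slices contiguous for the NIXL
# head-splitting path in heterogeneous TP.
LOGICAL = ("num_blocks", "2", "block_size", "num_kv_heads", "head_size")

def apply_stride_order(shape, order):
    return tuple(shape[i] for i in order)

assert apply_stride_order(LOGICAL, (0, 1, 2, 3, 4)) == LOGICAL  # NHD: identity
assert apply_stride_order(LOGICAL, (0, 1, 3, 2, 4)) == (
    "num_blocks", "2", "num_kv_heads", "block_size", "head_size"  # HND
)
```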

View File

@@ -70,9 +70,15 @@ def reshape_and_cache_kernel_flash(
+ (cur_dim % x)
)
else:
-tgt_base = block_idx * block_stride + block_offset * page_stride
-tgt_idx_k = tgt_base + tile_pos
-tgt_idx_v = tgt_base + tile_pos
+cur_head = tile_pos // head_size
+cur_dim = tile_pos % head_size
+tgt_idx_k = (
+    block_idx * block_stride
+    + block_offset * page_stride
+    + cur_head * head_stride
+    + cur_dim
+)
+tgt_idx_v = tgt_idx_k
# [TILE_SIZE]
key_load = tl.load(
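
Note (illustration, not part of this commit): the kernel change replaces the single flat tile offset with stride-aware head indexing, so writes land correctly even when heads are not stored contiguously per token (e.g. an HND cache). A plain-Python restatement of the index arithmetic; the function name and argument order are illustrative only.

```python
# Hedged sketch of the target-index computation added above (plain Python,
# not Triton). tile_pos enumerates the num_kv_heads * head_size elements of
# one token's K (or V) slice; it is split into a head index and a within-head
# offset, each scaled by the cache's actual strides.
def target_index(block_idx, block_offset, tile_pos,
                 block_stride, page_stride, head_stride, head_size):
    cur_head = tile_pos // head_size
    cur_dim = tile_pos % head_size
    return (block_idx * block_stride
            + block_offset * page_stride
            + cur_head * head_stride
            + cur_dim)
```

When head_stride equals head_size (heads contiguous per token, as in NHD), this collapses to the old `tgt_base + tile_pos`; a larger head_stride, as in an HND cache where block_size sits between the head and dim axes, is exactly the case the old formula mis-addressed.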