[NIXL][BUG] Fix Triton heterogeneous TP (#37940)
Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
@@ -23,6 +23,10 @@ hybrid_ssm_configs=(
   # TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
   "ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
 )
+sw_attn_configs=(
+  "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
+  "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
+)

 # Select config array based on DP_EP env var
 if [[ -n "${DP_EP:-}" ]]; then
@@ -31,6 +35,9 @@ if [[ -n "${DP_EP:-}" ]]; then
 elif [[ -n "${HYBRID_SSM:-}" ]]; then
   configs=("${hybrid_ssm_configs[@]}")
   echo "HYBRID_SSM is set, using hybrid_ssm_configs."
+elif [[ -n "${SW_ATTN:-}" ]]; then
+  configs=("${sw_attn_configs[@]}")
+  echo "SW_ATTN is set, using sw_attn_configs."
 else
   configs=("${tp_configs[@]}")
 fi
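Each config entry above packs per-run environment variables into one space-separated string, with VLLM_SERVE_EXTRA_ARGS carrying comma-separated CLI flags. A minimal Python sketch of how one entry could be decoded; the decode_config helper and its splitting rules are illustrative assumptions, not the harness's actual code:

# Illustrative only: decode one config entry into env vars plus extra CLI args.
def decode_config(entry: str) -> tuple[dict[str, str], list[str]]:
    env = dict(token.split("=", 1) for token in entry.split())
    # VLLM_SERVE_EXTRA_ARGS packs CLI flags as a comma-separated list.
    extra_args = env.pop("VLLM_SERVE_EXTRA_ARGS", "")
    return env, [a for a in extra_args.split(",") if a]

env, args = decode_config(
    "ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 "
    "MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
)
assert env["DECODER_TP_SIZE"] == "2"
assert args == ["--max-model-len", "8192"]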
@@ -1586,10 +1586,18 @@ def test_register_kv_caches(
     expected_base_addrs: list[int]
     expected_num_entries: int
     kv_caches: dict[str, torch.Tensor]
-    assert str(enable_cross_layers).lower() != "true" or (
-        (attn_backend not in ("FLASH_ATTN", "FLASHINFER"))
-        or connector.prefer_cross_layer_blocks
-    )
+    if str(enable_cross_layers).lower() == "true":
+        assert connector.prefer_cross_layer_blocks == (
+            attn_backend in ("FLASH_ATTN", "FLASHINFER", "TRITON_ATTN")
+        )
+    else:
+        assert not connector.prefer_cross_layer_blocks
+
+    test_shape = backend_cls.get_kv_cache_shape(
+        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+    )
+    is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
+
     if connector.prefer_cross_layer_blocks:
         with set_current_vllm_config(vllm_config):
             _, cross_layers_kv_cache, _ = (
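The rewritten assertions derive is_blocks_first by probing the backend's KV-cache shape with dummy sizes: a 5-D result whose leading dimension equals the probe's num_blocks=1 means blocks are the outermost dimension. A self-contained sketch of that probe; FakeBlocksFirstBackend is a stand-in, not a real vLLM class:

# Standalone illustration of the probe used above: a backend whose KV-cache
# shape is (num_blocks, 2, block_size, num_kv_heads, head_size) is "blocks first".
class FakeBlocksFirstBackend:
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

test_shape = FakeBlocksFirstBackend.get_kv_cache_shape(
    num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Same check as in the test: 5-D, and the probe's num_blocks=1 leads the shape.
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
assert is_blocks_first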
@@ -1619,7 +1627,7 @@ def test_register_kv_caches(
         ]
         expected_num_entries = 1

-        expected_blocks_count = 8
+        expected_blocks_count = num_blocks * (2 if is_blocks_first else 1)

         kv_caches = {"all-layers": cross_layers_kv_cache}
     else:
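The hard-coded expected_blocks_count = 8 assumed one registered entry per block; with a blocks-first layout the K and V halves of each block are presumably tracked as separate entries, hence the factor of two. A quick check of the new expression with illustrative values:

# Illustrative values only: the new expected_blocks_count expression.
num_blocks = 8
for is_blocks_first, expected in ((False, 8), (True, 16)):
    assert num_blocks * (2 if is_blocks_first else 1) == expected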
@@ -1639,12 +1647,6 @@ def test_register_kv_caches(
         }

         # Store tensor info for validation
-
-        test_shape = backend_cls.get_kv_cache_shape(
-            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
-        )
-        is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
-
         if is_blocks_first:
             expected_tensor_size = (
                 shared_tensor.element_size() * shared_tensor.numel()
@@ -1696,13 +1698,13 @@ def test_register_kv_caches(

     if connector.prefer_cross_layer_blocks:
         num_blocks = 8
-        expected_block_len = expected_tensor_size // num_blocks
     else:
         num_blocks = kv_cache_config.num_blocks
-        if is_blocks_first:
-            expected_block_len = expected_tensor_size // num_blocks // 2
-        else:
-            expected_block_len = expected_tensor_size // num_blocks
+    if is_blocks_first:
+        expected_block_len = expected_tensor_size // num_blocks // 2
+    else:
+        expected_block_len = expected_tensor_size // num_blocks

     for i, block_entry in enumerate(blocks_data):
         block_start_addr, block_len, tp_rank = block_entry
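The per-block entry length is now derived the same way in both branches: divide the tensor's byte size by the number of blocks, and halve it again when the layout is blocks-first, since each block then contributes separate K and V entries. A small numeric sketch with made-up sizes:

# Illustrative numbers only: how the per-block entry length is derived.
element_size = 2                      # e.g. fp16
numel = 8 * 2 * 16 * 4 * 64           # (blocks, K/V, block_size, heads, head_dim)
expected_tensor_size = element_size * numel
num_blocks = 8

block_len_blocks_first = expected_tensor_size // num_blocks // 2  # K and V halves
block_len_otherwise = expected_tensor_size // num_blocks          # whole block

assert block_len_blocks_first * 2 * num_blocks == expected_tensor_size
assert block_len_otherwise * num_blocks == expected_tensor_size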
@@ -330,6 +330,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
         if backend.get_name() not in (
             "FLASH_ATTN",
             "FLASHINFER",
+            "TRITON_ATTN",
         ):
             return False

@@ -2118,6 +2119,21 @@ class NixlConnectorWorker:
                 "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
             )

+        # Heterogeneous TP requires head-splitting, which only works with
+        # HND layout. MLA and replicated-KV cases don't split on heads.
+        # Mamba doesn't support heterogeneous TP.
+        if (
+            abs(tp_ratio) != 1
+            and not self.use_mla
+            and not self.kv_topo.is_kv_replicated(remote_engine_id)
+            and kv_cache_layout != "HND"
+            and not self.enable_permute_local_kv
+        ):
+            raise RuntimeError(
+                "Heterogeneous TP head-dimension splitting requires contiguous heads. "
+                "Use HND layout on the prefill side."
+            )
+
         # Block len can only vary across layers when using MLA.
         remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
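The new guard encodes why heterogeneous TP (abs(tp_ratio) != 1) needs the HND layout unless the KV can be permuted locally: head-dimension splitting hands each decode rank a slice of heads, and that slice is only a contiguous byte range per block when heads are the outermost within-block dimension. A hedged torch sketch of the contiguity difference; the shapes and tp_ratio are arbitrary illustration values:

import torch

# Illustration only: slicing one TP rank's heads out of a single block.
block_size, num_kv_heads, head_size, tp_ratio = 16, 8, 64, 2
heads_per_rank = num_kv_heads // tp_ratio

hnd_block = torch.zeros(num_kv_heads, block_size, head_size)  # HND: heads outermost
nhd_block = torch.zeros(block_size, num_kv_heads, head_size)  # NHD: heads inner

# In HND the per-rank head slice is one contiguous chunk of memory...
assert hnd_block[:heads_per_rank].is_contiguous()
# ...while in NHD the same heads are interleaved across every token position.
assert not nhd_block[:, :heads_per_rank].is_contiguous()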
@@ -29,6 +29,7 @@ from vllm.v1.attention.backend import (
     CommonAttentionMetadata,
     MultipleOf,
 )
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
 from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
     triton_reshape_and_cache_flash,
@@ -309,12 +310,20 @@ class TritonAttentionBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         # `stride_order` indicates the permutation that gets
         # us from `get_kv_cache_shape` to the actual memory layout we want.
-        if include_num_layers_dimension:
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD" and include_num_layers_dimension:
             # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
             return (1, 0, 2, 3, 4, 5)
-        # (num_blocks, 2, block_size, num_kv_heads, head_size)
-        return (0, 1, 2, 3, 4)
+        elif cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND" and include_num_layers_dimension:
+            # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
+            return (1, 2, 4, 0, 3, 5)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout: {cache_layout}")
+        return stride_order

     @staticmethod
     def use_cascade_attention(*args, **kwargs) -> bool:
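For context, stride_order reads as the permutation taking the logical get_kv_cache_shape dimensions to the physical allocation order, so (0, 1, 3, 2, 4) swaps block_size and num_kv_heads and makes heads the slowest-varying within-block dimension (HND). A hedged torch sketch of that reading; the dimension sizes are arbitrary:

import torch

# Illustration: apply an HND-style stride order to the logical 5-D KV shape
# (num_blocks, 2, block_size, num_kv_heads, head_size).
logical_shape = (4, 2, 16, 8, 64)
stride_order = (0, 1, 3, 2, 4)  # HND, without the layers dimension

# Allocate in the physical (permuted) order, then view it back logically.
physical = torch.empty([logical_shape[d] for d in stride_order])
logical_view = physical.permute(
    [stride_order.index(d) for d in range(len(stride_order))]
)
assert logical_view.shape == logical_shape
# Heads now vary slowest among the within-block dims: their stride is the
# largest of the last three dimensions (HND).
strides = logical_view.stride()
assert strides[3] > strides[2] and strides[3] > strides[4]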
@@ -70,9 +70,15 @@ def reshape_and_cache_kernel_flash(
                 + (cur_dim % x)
             )
         else:
-            tgt_base = block_idx * block_stride + block_offset * page_stride
-            tgt_idx_k = tgt_base + tile_pos
-            tgt_idx_v = tgt_base + tile_pos
+            cur_head = tile_pos // head_size
+            cur_dim = tile_pos % head_size
+            tgt_idx_k = (
+                block_idx * block_stride
+                + block_offset * page_stride
+                + cur_head * head_stride
+                + cur_dim
+            )
+            tgt_idx_v = tgt_idx_k

         # [TILE_SIZE]
         key_load = tl.load(
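In the rewritten else branch the flat target offset is built from an explicit head index and within-head offset via the cache's head stride, rather than assuming a token's heads sit contiguously after the token offset. A NumPy sketch of the same address arithmetic on a simplified HND cache; the shape omits the K/V and layer dimensions the real cache carries:

import numpy as np

# Illustration only: recompute the kernel's target offset for one (block, token)
# pair in an HND cache of shape (num_blocks, num_kv_heads, block_size, head_size).
num_blocks, num_kv_heads, block_size, head_size = 4, 8, 16, 64
cache = np.arange(num_blocks * num_kv_heads * block_size * head_size).reshape(
    num_blocks, num_kv_heads, block_size, head_size
)

block_stride = num_kv_heads * block_size * head_size  # jump to the next block
head_stride = block_size * head_size                  # jump to the next head
page_stride = head_size                               # jump to the next token slot

block_idx, block_offset = 2, 5                        # arbitrary block / token slot
flat = cache.reshape(-1)
for tile_pos in range(num_kv_heads * head_size):      # one token's elements
    cur_head = tile_pos // head_size
    cur_dim = tile_pos % head_size
    tgt_idx = (
        block_idx * block_stride
        + block_offset * page_stride
        + cur_head * head_stride
        + cur_dim
    )
    assert flat[tgt_idx] == cache[block_idx, cur_head, block_offset, cur_dim]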