[P/D] Heterogeneous TP (#18833)
Signed-off-by: nicklucche <nlucches@redhat.com>
@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # NOTE: When running disaggregated P/D with NIXL, the HND layout is
+        # used for faster transfer. `stride_order` is the permutation that
+        # gets us from `get_kv_cache_shape` to the desired memory layout.
+        cache_layout = get_kv_connector_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
 
 @dataclass
 class FlashAttentionMetadata:
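The new hook only reports a permutation; the cache allocator is expected to apply it. Below is a minimal sketch of how such a stride order maps the logical KV cache shape onto physical memory. It is not part of this commit: the helper name allocate_kv_cache and the toy sizes are hypothetical, and the real allocation lives in the model runner.

import torch

# Toy sizes for the logical shape returned by get_kv_cache_shape():
# (2, num_blocks, block_size, num_kv_heads, head_size).
kv_cache_shape = (2, 4, 16, 8, 64)

def allocate_kv_cache(stride_order: tuple[int, ...]) -> torch.Tensor:
    # Hypothetical helper: allocate in the permuted (physical) order,
    # then permute back so the tensor is indexed with the logical shape
    # while its memory follows the requested layout.
    physical_shape = tuple(kv_cache_shape[i] for i in stride_order)
    inverse = [stride_order.index(i) for i in range(len(stride_order))]
    return torch.empty(physical_shape).permute(inverse)

nhd = allocate_kv_cache((0, 1, 2, 3, 4))  # NHD: identity permutation
hnd = allocate_kv_cache((0, 1, 3, 2, 4))  # HND: swap block_size/num_kv_heads

# Both tensors are indexed identically, but only NHD is contiguous;
# HND strides put each KV head's blocks next to each other in memory.
assert nhd.shape == hnd.shape == kv_cache_shape
assert nhd.is_contiguous() and not hnd.is_contiguous()

Keeping each head's data contiguous is presumably why HND is preferred here: with heterogeneous TP, prefill and decode workers shard KV heads differently, so transferring a subset of heads over NIXL becomes a contiguous copy rather than a strided gather.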