[PD][Nixl] Add support for hybrid SSM-FA models (#36687)
@@ -18,11 +18,19 @@ dp_ep_configs=(
    "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
    "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)

hybrid_ssm_configs=(
    "ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
    # TODO: (NickLucche) Address the async-scheduling issue with TP>1 separately, as it may impact other models.
    "ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)

# Select config array based on the DP_EP / HYBRID_SSM env vars
if [[ -n "${DP_EP:-}" ]]; then
    configs=("${dp_ep_configs[@]}")
    echo "DP_EP is set, using dp_ep_configs"
elif [[ -n "${HYBRID_SSM:-}" ]]; then
    configs=("${hybrid_ssm_configs[@]}")
    echo "HYBRID_SSM is set, using hybrid_ssm_configs"
else
    configs=("${tp_configs[@]}")
fi

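The VLLM_SERVE_EXTRA_ARGS value packs CLI flags into a single comma-separated token. A minimal sketch of how a harness might expand it into argv tokens; `build_serve_args` is a hypothetical helper, not part of this script:

# Hypothetical helper: expand the comma-separated VLLM_SERVE_EXTRA_ARGS
# value into argv tokens for `vllm serve`.
import os

def build_serve_args() -> list[str]:
    # "--max-model-len,8192,--trust-remote-code" -> ["--max-model-len", "8192", ...]
    raw = os.environ.get("VLLM_SERVE_EXTRA_ARGS", "")
    return [tok for tok in raw.split(",") if tok]
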
@@ -18,6 +18,7 @@ EXPECTED_VALUES = {
    "deepseek-ai/deepseek-vl2-tiny": 0.19,
    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
    "google/gemma-3-4b-it": 0.74,
    "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": 0.84,
}

SIMPLE_PROMPT = (

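Illustrative only: one way an accuracy test could gate a run on these baselines. `measured_accuracy` and the tolerance are hypothetical names, not taken from the actual test file:

def check_accuracy(model: str, measured_accuracy: float, rel_tol: float = 0.05) -> None:
    # Fail if the measured accuracy drops below the recorded baseline.
    expected = EXPECTED_VALUES[model]
    assert measured_accuracy >= expected * (1 - rel_tol), (
        f"{model}: got {measured_accuracy:.3f}, expected >= {expected:.3f}"
    )
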
@@ -53,7 +53,13 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import set_kv_cache_layout
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    KVCacheTensor,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin

@@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init):

    # Prefill connector will register KV cache to populate proper handshake
    # metadata.
    # TODO: this must match the values used in the KV cache config.
    kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
    kv_cache_groups = [
        KVCacheGroupSpec(
            ["layer0", "layer1", "layer2"],
            FullAttentionSpec(
                block_size=16,
                num_kv_heads=4,
                head_size=16,
                dtype=torch.float16,
            ),
        )
    ]
    kv_cache_config = KVCacheConfig(
        num_blocks=2, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
    )
    prefill_connector = NixlConnector(
        vllm_config, KVConnectorRole.WORKER, kv_cache_config
    )

@@ -437,7 +455,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
        self.kv_cache_layout = kv_cache_layout
        # Mock register_kv_caches attribute needed for tests that do not call it.
        self.src_xfer_handles_by_block_size = {self.block_size: 1}
        test_shape = self.attn_backend.get_kv_cache_shape(
        test_shape = self.attn_backends[0].get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        self.kv_topo = TpKVTopology(

@@ -447,7 +465,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
            remote_block_size=self._block_size,  # shared state
            is_mla=self.use_mla,
            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
            attn_backend=self.attn_backend,
            attn_backends=self.attn_backends,
            tensor_shape=test_shape,
        )

@@ -501,6 +519,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                # is started. We mock HND here.
                kv_cache_layout="HND",
                block_size=self.block_size,
                ssm_sizes=(0, 0),
            ),
            remote_tp_rank=remote_tp_rank,
            remote_tp_size=remote_tp_size,

@@ -951,6 +970,7 @@ class TestNixlHandshake:
            block_lens=worker.block_len_per_layer,
            kv_cache_layout=mismatched_layout,
            block_size=worker.block_size,
            ssm_sizes=(0, 0),
        )

        with pytest.raises(RuntimeError):

@@ -1006,6 +1026,7 @@ class TestNixlHandshake:
            block_lens=[i * 2 for i in worker.block_len_per_layer],
            kv_cache_layout="HND",
            block_size=worker.block_size,
            ssm_sizes=(0, 0),
        )

        # We don't check layout for homogeneous TP and MLA for now, as the

@@ -1496,9 +1517,47 @@ def test_register_kv_caches(
    # test run if not mocking.
    mock_get_attn_backend.return_value = backend_cls
    mock_get_attn_backends.return_value = [backend_cls]
    num_layers = 32
    block_size = 16
    num_blocks = 8
    num_heads = 4
    head_size = 16

    # TODO (NickLucche) the fact that the connector depends on kv_cache_config for
    # init, but the cross-layer preference can't be inferred prior to creating
    # kv_cache_config, is a bit awkward.
    dummy_connector = NixlConnector(
        vllm_config,
        KVConnectorRole.WORKER,
        make_kv_cache_config(block_size=block_size),
    )
    kv_cache_spec = FullAttentionSpec(
        block_size=block_size,
        num_kv_heads=num_heads,
        head_size=head_size,
        dtype=torch.float16,
    )
    if dummy_connector.prefer_cross_layer_blocks:
        kv_cache_config = KVCacheConfig(
            num_blocks=num_blocks,
            kv_cache_tensors=[
                KVCacheTensor(
                    size=kv_cache_spec.page_size_bytes * num_blocks,
                    shared_by=["all-layers"],
                )
                for _ in range(num_layers)
            ],
            kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)],
        )
    else:
        kv_cache_config = KVCacheConfig(
            num_blocks=num_blocks,
            kv_cache_tensors=[],
            kv_cache_groups=[
                KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec)
            ],
        )
    # Create connector
    kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config,

@@ -1526,35 +1585,6 @@ def test_register_kv_caches(
        or connector.prefer_cross_layer_blocks
    )
    if connector.prefer_cross_layer_blocks:
        num_layers = 32
        block_size = 16
        num_blocks = 8
        # Keep the fake worker's expected num_blocks in sync with the
        # cross-layer tensor we are about to register.
        worker_kv_cache_config = make_kv_cache_config(
            block_size=block_size, num_blocks=num_blocks
        )
        connector.connector_worker.kv_cache_config = worker_kv_cache_config
        connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
        kv_cache_spec = AttentionSpec(
            block_size=block_size,
            num_kv_heads=4,
            head_size=64,
            dtype=torch.bfloat16,
        )
        kv_cache_config = KVCacheConfig(
            num_blocks=num_blocks,
            kv_cache_tensors=[
                KVCacheTensor(
                    size=kv_cache_spec.page_size_bytes * num_blocks,
                    shared_by=["dummy-layer"],
                )
                for i in range(num_layers)
            ],
            # allocate_uniform_kv_caches does not use this
            kv_cache_groups=[],
        )

        with set_current_vllm_config(vllm_config):
            _, cross_layers_kv_cache, _ = (
                KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(

@@ -1586,12 +1616,8 @@ def test_register_kv_caches(
        expected_blocks_count = 8

        kv_caches = {"all-layers": cross_layers_kv_cache}

    else:
        # Create test kv cache tensors using proper backend shape
        kv_cache_spec = cast(
            AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
        )
        kv_cache_shape = backend_cls.get_kv_cache_shape(
            num_blocks=kv_cache_config.num_blocks,
            block_size=kv_cache_spec.block_size,

@@ -2261,7 +2287,7 @@ def test_compatibility_hash_validation(
    kv_cache_spec = cast(
        AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
    )
    kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
    kv_cache_shape = decode_worker.attn_backends[0].get_kv_cache_shape(
        num_blocks=kv_cache_config.num_blocks,
        block_size=kv_cache_spec.block_size,
        num_kv_heads=kv_cache_spec.num_kv_heads,

@@ -2269,10 +2295,14 @@ def test_compatibility_hash_validation(
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
    # Build kv_caches from the actual layer names in kv_cache_config so that
    # _layer_specs lookups in register_kv_caches always find a matching key.
    layer_names = [
        name for group in kv_cache_config.kv_cache_groups for name in group.layer_names
    ]
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
        name: shared_tensor if i % 2 == 0 else unique_tensor
        for i, name in enumerate(layer_names)
    }
    decode_connector.register_kv_caches(kv_caches)

@@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation(
        block_lens=[4096 * prefill_block_size],  # slot_size * block_size
        kv_cache_layout="HND",
        block_size=prefill_block_size,
        ssm_sizes=(0, 0),
    )
    handshake_payload = NixlHandshakePayload(
        compatibility_hash=remote_hash,

@@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
        remote_block_size=decode_worker._block_size,  # shared state
        is_mla=decode_worker.use_mla,
        total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
        attn_backend=backend,
        attn_backends=[backend],
        tensor_shape=test_shape,
    )

@@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma():
    # Simulate HMA scenario: logical block size = 32, kernel block size = 16,
    # so each logical block maps to 2 kernel blocks, e.g. [0] -> [0, 1].
    worker._physical_blocks_per_logical_kv_block = 2
    # FA + SW groups (neither is MambaSpec, so both get expanded)
    worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True)

    # Test conversion: FA + SW group
    logical_block_ids = [[0, 1, 2], [3, 4]]

@@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure():
    assert len(req_meta.remote.block_ids) == 2
    assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
    assert list(req_meta.remote.block_ids[1]) == [18, 19, 20, 21]


@pytest.mark.cpu_test
def test_get_block_descs_ids_hybrid_ssm():
    """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
    when ratio=1 (no kernel block size mismatch)."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    worker = object.__new__(NixlConnectorWorker)

    num_blocks = 100
    engine_id = "test-engine"
    worker.num_regions = 2
    worker.dst_num_blocks = {engine_id: num_blocks}
    worker._has_mamba = True
    worker._is_mamba_group = [False, True]
    worker._physical_blocks_per_logical_kv_block = 1
    # num_descs = num_regions * num_blocks (no blocks_first doubling)
    worker.num_descs = 2 * num_blocks

    fa_blocks = [3, 5]
    ssm_blocks = [1, 2]
    result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))

    # FA group: stride=num_blocks=100, offset=0
    # region0: [3, 5], region1: [103, 105]
    # SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
    # offset=num_descs=200
    # region0: [201, 202], region1: [301, 302]
    expected = [3, 5, 103, 105, 201, 202, 301, 302]
    assert list(result) == expected, f"Expected {expected}, got {list(result)}"


@pytest.mark.cpu_test
def test_get_block_descs_ids_kernel_block_mismatch():
    """Test _get_block_descs_ids uses different strides for FA (kernel blocks)
    vs SSM (logical blocks) when ratio > 1."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    worker = object.__new__(NixlConnectorWorker)

    ratio = 4
    logical_blocks = 100
    num_blocks = logical_blocks * ratio  # 400 kernel blocks
    engine_id = "test-engine"
    worker.num_regions = 2
    worker.dst_num_blocks = {engine_id: num_blocks}
    worker._has_mamba = True
    worker._is_mamba_group = [False, True]
    worker._physical_blocks_per_logical_kv_block = ratio
    worker.num_descs = 2 * num_blocks  # 800

    fa_blocks = [3, 7]  # kernel-level block IDs
    ssm_blocks = [1, 2]  # logical block IDs
    result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))

    # FA group: stride=num_blocks=400, offset=0
    # region0: [3, 7], region1: [403, 407]
    # SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
    # region0: [801, 802], region1: [901, 902]
    expected = [3, 7, 403, 407, 801, 802, 901, 902]
    assert list(result) == expected, f"Expected {expected}, got {list(result)}"


@pytest.mark.cpu_test
def test_nixl_metadata_hybrid_ssm_block_ids():
    """Test NixlConnectorMetadata correctly stores block IDs for FA + SSM
    groups with different block counts (kernel mismatch active)."""
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorMetadata,
    )

    metadata = NixlConnectorMetadata()

    # FA: 8 kernel blocks (2 logical * ratio=4), SSM: 2 logical blocks
    fa_blocks = [0, 1, 2, 3, 4, 5, 6, 7]
    ssm_blocks = [0, 1]

    metadata.add_new_req_to_recv(
        request_id="test-req-hybrid",
        local_block_ids=(fa_blocks, ssm_blocks),
        kv_transfer_params={
            "remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [20, 21]),
            "remote_engine_id": "remote-engine",
            "remote_request_id": "prefill-test-req-hybrid",
            "remote_host": "localhost",
            "remote_port": 1234,
            "tp_size": 1,
        },
    )

    assert "test-req-hybrid" in metadata.reqs_to_recv
    req_meta = metadata.reqs_to_recv["test-req-hybrid"]

    # Verify local block IDs: different lengths per group
    assert len(req_meta.local_block_ids) == 2
    assert list(req_meta.local_block_ids[0]) == fa_blocks
    assert list(req_meta.local_block_ids[1]) == ssm_blocks
    assert len(req_meta.local_block_ids[0]) != len(req_meta.local_block_ids[1])

    # Verify remote block IDs: same asymmetry preserved
    assert req_meta.remote is not None
    assert len(req_meta.remote.block_ids) == 2
    assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
    assert list(req_meta.remote.block_ids[1]) == [20, 21]
    assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])

@@ -16,10 +16,12 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
    from vllm.v1.kv_cache_interface import KVCacheSpec

logger = init_logger(__name__)

@@ -328,22 +330,26 @@ class TpKVTopology:
    remote_tp_size: dict[EngineId, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    attn_backends: list[type[AttentionBackend]]
    engine_id: EngineId
    remote_block_size: dict[EngineId, int]
    tensor_shape: torch.Size | None = None
    is_mamba: bool = False

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        _MOCK_BLOCK_SIZE = 16
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
        )
        logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
        attn_backend = self.attn_backends[0]
        if not self.is_mamba:
            _MOCK_BLOCK_SIZE = 16
            kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
                num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
            )
            logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
        # Non-MLA backend caches have 5 dims [2, num_blocks, H, N, D];
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
        # Hybrid SSM models assume a single blocks_first layout
        self._is_kv_layout_blocks_first = self.is_mamba or (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )

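A standalone replay of the dimension check above, for clarity (shapes are illustrative): with num_blocks mocked to 1, a K/V-first backend reports [2, 1, ...] while a blocks_first backend reports [1, 2, ...], and hybrid SSM forces blocks_first regardless:

def is_blocks_first(kv_cache_shape: tuple[int, ...], is_mamba: bool) -> bool:
    # Mirrors the __post_init__ logic above; num_blocks is mocked to 1.
    return is_mamba or (len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1)

assert is_blocks_first((1, 2, 16, 1, 1), is_mamba=False)      # blocks_first backend
assert not is_blocks_first((2, 1, 16, 1, 1), is_mamba=False)  # K/V-first backend
assert is_blocks_first((), is_mamba=True)                     # hybrid SSM case
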
@@ -360,7 +366,7 @@ class TpKVTopology:
            _MOCK_NUM_LAYERS = 80
            kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
        try:
            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                include_num_layers_dimension=self._cross_layers_blocks
            )
        except (AttributeError, NotImplementedError):

@@ -483,6 +489,30 @@ class TpKVTopology:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_ranks(remote_tp_size)

    def get_transfer_cache_regions(
        self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
    ) -> list[torch.Tensor] | torch.Tensor:
        """Return the cache tensor(s) to register as NIXL memory regions,
        also accounting for the specifics of hybrid SSM models.
        """
        if isinstance(layer_spec, MambaSpec):
            # Register the whole shared KV cache tensor, including SSM/Conv. This is
            # similar to FI, with the difference that SSM and Conv have different sizes.
            conv, ssm = cache
            return [conv]

        # Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`.
        if self.is_mamba and cache.shape[0] == 2:
            # When MAMBA is present, all backends are blocks first, so that blocks
            # can be shared between attention layers and mamba layers. Runner
            # `_update_hybrid_attention_mamba_layout` already adjusted strides
            # for FlashAttn-like backends, so it's num_blocks first.
            # Swap [2<>num_blocks] dims to get the required layout for hybrid SSM.
            cache = cache.transpose(0, 1)

        # Regular case: backends like FA register K/V in separate regions
        return cache if self.split_k_and_v else [cache]


def get_current_attn_backends(
    vllm_config: VllmConfig, layer_names: list[str] | None = None

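A minimal sketch of the three branches in get_transfer_cache_regions above, with made-up shapes (this is illustrative, not the connector's actual call path):

import torch

# 1) MambaSpec layers pass a (conv, ssm) tuple; only the conv view is returned,
#    since SSM+Conv share one region that is split logically later.
# 2) Hybrid SSM with a [2, num_blocks, ...] cache: transpose so num_blocks leads.
cache = torch.zeros(2, 8, 16, 4, 64)   # hypothetical [2, num_blocks, H, N, D]
blocks_first = cache.transpose(0, 1)   # [num_blocks, 2, H, N, D]
# 3) Regular FA case with split_k_and_v: the cache is iterated as K and V regions.
k_view, v_view = cache[0], cache[1]
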
@@ -564,7 +564,7 @@ class MooncakeConnectorWorker:
            remote_block_size=self._block_size,  # shared state
            is_mla=self.use_mla,
            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
            attn_backend=backend,
            attn_backends=[backend],
        )

        self.async_zmq_ctx = zmq.asyncio.Context()

@@ -59,7 +59,12 @@ from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    MambaSpec,
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size

@@ -159,6 +164,7 @@ class NixlAgentMetadata:
    block_lens: list[int]
    kv_cache_layout: str
    block_size: int
    ssm_sizes: tuple[int, int]


@dataclass

@@ -310,6 +316,15 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
    @property
    def prefer_cross_layer_blocks(self) -> bool:
        if any(
            isinstance(group.kv_cache_spec, MambaSpec)
            for group in self.kv_cache_config.kv_cache_groups
        ):
            # Hybrid SSM models do not yet support cross-layer layout
            return False

        backend = get_current_attn_backend(self._vllm_config)
        if backend.get_name() not in (
            "FLASH_ATTN",

@@ -335,12 +350,9 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        assert vllm_config.kv_transfer_config is not None
        assert vllm_config.kv_transfer_config.engine_id is not None
        for group in kv_cache_config.kv_cache_groups:
            if isinstance(group.kv_cache_spec, MambaSpec):
                raise ValueError("NixlConnector does not support Mamba models.")
        self.kv_cache_config = kv_cache_config
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
        self.kv_transfer_config = vllm_config.kv_transfer_config
        if role == KVConnectorRole.SCHEDULER:

@@ -434,11 +446,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        assert self.connector_worker is not None

        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}

        self.connector_worker.register_kv_caches(kv_caches)
        self.connector_worker.register_cross_layers_kv_caches(kv_cache)

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        assert self.connector_worker is not None

@@ -962,6 +970,40 @@ class NixlConnectorWorker:
            )
        )
        self.kv_cache_config = kv_cache_config
        self._layer_specs = {
            layer: group.kv_cache_spec
            for group in kv_cache_config.kv_cache_groups
            for layer in group.layer_names
        }
        self.hma_group_size = len(kv_cache_config.kv_cache_tensors)

        # Mamba metadata
        self._is_mamba_group = [
            isinstance(group.kv_cache_spec, MambaSpec)
            for group in kv_cache_config.kv_cache_groups
        ]
        mamba_ssm_size = (0, 0)
        self._has_mamba = any(self._is_mamba_group)
        if self._has_mamba:
            assert self._is_hma_required
            mamba_spec = next(
                spec
                for spec in self._layer_specs.values()
                if isinstance(spec, MambaSpec)
            )
            conv_nbytes, ssm_nbytes = (
                torch.tensor([], dtype=mamba_spec.dtypes[0]).element_size(),  # type: ignore[misc]
                torch.tensor([], dtype=mamba_spec.dtypes[1]).element_size(),  # type: ignore[misc]
            )
            conv_shape, ssm_shape = (
                torch.Size(mamba_spec.shapes[0]),
                torch.Size(mamba_spec.shapes[1]),
            )
            mamba_ssm_size = (
                conv_shape.numel() * conv_nbytes,
                ssm_shape.numel() * ssm_nbytes,
            )
        self._mamba_ssm_size = mamba_ssm_size

        # Agent.
        non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]

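A worked example of the byte-size computation above, with made-up Mamba state shapes and dtypes (the real values come from the model's MambaSpec):

import torch

conv_shape, conv_dtype = torch.Size((4, 128)), torch.bfloat16    # 2 bytes/elem
ssm_shape, ssm_dtype = torch.Size((16, 64, 128)), torch.float32  # 4 bytes/elem
mamba_ssm_size = (
    conv_shape.numel() * torch.tensor([], dtype=conv_dtype).element_size(),  # 4*128*2 = 1024
    ssm_shape.numel() * torch.tensor([], dtype=ssm_dtype).element_size(),    # 16*64*128*4 = 524288
)
assert mamba_ssm_size == (1024, 524288)
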
@@ -1106,9 +1148,9 @@ class NixlConnectorWorker:

        # Get the attention backend from the first layer
        # NOTE (NickLucche) models with multiple backends are not supported yet
        self.attn_backend = get_current_attn_backend(vllm_config)
        self.attn_backends = get_current_attn_backends(vllm_config)
        self.backend_name = self.attn_backends[0].get_name()

        self.backend_name = self.attn_backend.get_name()
        self.kv_cache_layout = get_kv_cache_layout()
        self.host_buffer_kv_cache_layout = self.kv_cache_layout
        logger.info("Detected attention backend %s", self.backend_name)

@@ -1135,6 +1177,8 @@ class NixlConnectorWorker:
    def _sync_block_size_with_kernel(self) -> None:
        backends = get_current_attn_backends(self.vllm_config)
        kernel_block_size = select_common_block_size(self.block_size, backends)
        # Number of blocks not accounting for kernel block mismatches
        self._logical_num_blocks = self.num_blocks
        if self.block_size != kernel_block_size:
            logger.info_once(
                "User-specified logical block size (%s) does not match"

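A numeric sketch of the logical/kernel block-size mismatch handled above, matching the mapping exercised in the HMA test earlier (ratio=2, [0] -> [0, 1]):

block_size, kernel_block_size = 32, 16
ratio = block_size // kernel_block_size  # _physical_blocks_per_logical_kv_block
logical = [0, 3]
# Each logical block expands to `ratio` consecutive kernel blocks.
kernel = [b * ratio + i for b in logical for i in range(ratio)]
assert kernel == [0, 1, 6, 7]
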
@@ -1428,9 +1472,19 @@ class NixlConnectorWorker:

        fut.add_done_callback(request_ready)

    def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None:
        """Register a cross-layers KV cache tensor with NIXL.

        `use_uniform_kv_cache()` guarantees a single KV cache group whose
        layers all share the same `AttentionSpec`, so any layer name from
        `_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
        """
        first_layer = next(iter(self._layer_specs))
        # Forward a real layer name rather than a synthetic key
        self.register_kv_caches({first_layer: kv_cache})

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl."""

        self.kv_topo = TpKVTopology(
            tp_rank=self.tp_rank,
            engine_id=self.engine_id,

@@ -1438,8 +1492,12 @@ class NixlConnectorWorker:
            remote_block_size=self._block_size,  # shared state
            is_mla=self.use_mla,
            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
            attn_backend=self.attn_backend,
            tensor_shape=next(iter(kv_caches.values())).shape,
            attn_backends=self.attn_backends,
            # SSM states come in (conv, ssm) tuples, so there is no single shape
            tensor_shape=next(iter(kv_caches.values())).shape
            if not self._has_mamba
            else None,
            is_mamba=self._has_mamba,
        )
        self.compat_hash = compute_nixl_compatibility_hash(
            self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
        )

@@ -1481,12 +1539,50 @@ class NixlConnectorWorker:
        # to better exploit the memory layout (i.e. num_blocks is the first dim).
        tensor_size_bytes = None

        # Enable different block lengths for different layers when MLA is used.
        # Enable different block lengths for different layers *only* when MLA is used.
        # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`.
        self.block_len_per_layer = list[int]()
        for layer_name, cache_or_caches in xfer_buffers.items():
            cache_list = (
                cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
            # NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to
            # that of FI, with blocks laid out as in `get_backend_aware_kv_block_len`.
            # However, the physical page_size may differ when the kernel requires a
            # specific block size. This leads to SSM and FA layers having different
            # num_blocks. The `_physical_blocks_per_logical_kv_block` ratio is used
            # to adjust for this.
            layer_spec = self._layer_specs[layer_name]
            if isinstance(layer_spec, UniformTypeKVCacheSpecs):
                # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
                layer_spec = layer_spec.kv_cache_specs[layer_name]
            cache_list = self.kv_topo.get_transfer_cache_regions(
                cache_or_caches, layer_spec
            )
            # `layer_spec.page_size_bytes` only accounts for the logical page_size,
            # that is, the page_size assuming a constant `self._logical_num_blocks`.
            physical_page_size = (
                layer_spec.page_size_bytes
                if isinstance(layer_spec, MambaSpec)
                else layer_spec.page_size_bytes
                // self._physical_blocks_per_logical_kv_block
            )
            # For when registering multiple tensors, e.g. K/V in separate regions.
            physical_page_size = physical_page_size // len(cache_list)
            if self.kv_topo._cross_layers_blocks:
                # When cross-layers blocks are used, multiply by the number of layers
                physical_page_size = physical_page_size * len(
                    self.kv_cache_config.kv_cache_tensors
                )
            num_blocks = (
                self._logical_num_blocks
                if isinstance(layer_spec, MambaSpec)
                else self.num_blocks
            )
            # `page_size` accounts for physical blocks, such that the KV cache is
            # always [`num_blocks` * `page_size`]
            curr_tensor_size_bytes = num_blocks * physical_page_size
            if tensor_size_bytes is None:
                tensor_size_bytes = curr_tensor_size_bytes

            # TODO (NickLucche) we could eventually unify how we handle FA/FI regions,
            # registering a single tensor for both K/V and splitting logically like FI.
            for cache in cache_list:
                base_addr = cache.data_ptr()
                if base_addr in seen_base_addresses:

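A worked example of the physical_page_size arithmetic above, using made-up numbers (the hunk ends mid-loop; this only replays the size math). A non-Mamba layer with a 64 KiB logical page and ratio=2 kernel blocks per logical block yields a 32 KiB physical page, halved again when K and V are registered as two regions:

logical_page_size_bytes = 64 * 1024
ratio = 2                                          # _physical_blocks_per_logical_kv_block
physical_page_size = logical_page_size_bytes // ratio  # 32768
physical_page_size //= 2                           # len(cache_list) == 2 for split K/V
num_blocks = 16                                    # physical (kernel) blocks
assert num_blocks * physical_page_size == 16 * 16384
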
@@ -1494,27 +1590,27 @@ class NixlConnectorWorker:
                    # across groups. This results in skipping all tensors but the ones
                    # pointed to by group0. Also, generally we will have more blocks
                    # per tensor but fewer regions.
                    logger.debug("Skipping %s because it's already seen", layer_name)
                    continue

                logger.debug(
                    "Registering layer %s with cache shape: %s", layer_name, cache.shape
                )
                seen_base_addresses.append(base_addr)
                curr_tensor_size_bytes = cache.numel() * cache.element_size()
                # Only record non-Mamba page sizes.
                if isinstance(layer_spec, MambaSpec):
                    self.block_len_per_layer.append(
                        physical_page_size // self._physical_blocks_per_logical_kv_block
                    )
                else:
                    self.block_len_per_layer.append(physical_page_size)

                if tensor_size_bytes is None:
                    tensor_size_bytes = curr_tensor_size_bytes

                assert cache.shape[0] == self.num_blocks, (
                assert cache.shape[0] == num_blocks, (
                    "All kv cache tensors must have the same number of blocks"
                )

                self.block_len_per_layer.append(
                    curr_tensor_size_bytes // self.num_blocks
                )

                if not self.use_mla:
                    # Different kv cache shape is not supported by HeteroTP
                    # Different kv cache shape is not supported by HeteroTP.
                    # This must also hold true for Mamba-like models.
                    assert tensor_size_bytes == curr_tensor_size_bytes, (
                        "All kv cache tensors must have the same size"
                    )

@@ -1533,6 +1629,21 @@ class NixlConnectorWorker:
        self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
        self.num_regions = len(caches_data)

        if self.kv_topo.is_kv_layout_blocks_first:
            # NOTE (NickLucche) When FlashInfer is used, memory is registered
            # with joint KV for each block. This minimizes the overhead in
            # registerMem, allowing faster descs queries. In order to be able to
            # split on the kv_heads dim as required by heterogeneous TP, one must
            # be able to index K/V separately. Hence we double the number
            # of 'virtual' regions here and halve `block_len` below.
            # Similarly for Mamba layers, we register SSM+Conv as a single region and
            # then duplicate it logically to be able to index SSM/Conv separately.
            self.num_regions *= 2

        # TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to
        # support heterogeneous TP.
        self.num_descs = self.num_regions * self.num_blocks

        descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type)
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)

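Quick arithmetic for the doubling above, with illustrative numbers: with 4 registered tensors and a blocks_first layout, K and V (or Conv and SSM) are indexed as separate 'virtual' regions, so descriptor counts double:

num_regions = 4
num_blocks = 8
num_regions *= 2                      # blocks_first: split K/V (or Conv/SSM) views
num_descs = num_regions * num_blocks  # 8 * 8 = 64 transfer descriptors
assert num_descs == 64
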
@@ -1542,17 +1653,21 @@ class NixlConnectorWorker:
        self.device_kv_caches = kv_caches
        self.dst_num_blocks[self.engine_id] = self.num_blocks

        if self.kv_topo.is_kv_layout_blocks_first:
            # NOTE (NickLucche) When FlashInfer is used, memory is registered
            # with joint KV for each block. This minimizes the overhead in
            # registerMem allowing faster descs queries. In order to be able to
            # split on kv_heads dim as required by heterogeneous TP, one must
            # be able to index K/V separately. Hence we double the number
            # of 'virtual' regions here and halve `block_len` below.
            self.num_regions *= 2
        if self._has_mamba:
            logger.info(
                "Hybrid SSM registration: num_blocks=%s, "
                "logical_num_blocks=%s, ratio=%s, num_regions=%s, "
                "num_descs=%s, mamba_ssm_size=%s, block_len_per_layer=%s",
                self.num_blocks,
                self._logical_num_blocks,
                self._physical_blocks_per_logical_kv_block,
                self.num_regions,
                self.num_descs,
                self._mamba_ssm_size,
                set(self.block_len_per_layer),
            )

        # Register local/src descr for NIXL xfer.
        self.seen_base_addresses = seen_base_addresses
        self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
            self.register_local_xfer_handler(self.block_size)
        )

@@ -1569,6 +1684,7 @@ class NixlConnectorWorker:
            if not self.use_host_buffer
            else self.host_buffer_kv_cache_layout,
            block_size=self.block_size,
            ssm_sizes=self._mamba_ssm_size,
        )
        # Wrap metadata in payload with hash for defensive decoding
        assert self.compat_hash is not None

@@ -1594,40 +1710,65 @@ class NixlConnectorWorker:
        data copy correctness.
        """
        assert self.kv_topo is not None
        kv_topo = self.kv_topo

        block_size_ratio = self.block_size // block_size
        blocks_data = []
        for i, base_addr in enumerate(self.seen_base_addresses):
            # The new block_len uses the prefill block_len,
            # and num_blocks is multiplied by the ratio N.
            kv_block_len = (
                self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
            )
            block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
            num_blocks = self.num_blocks * block_size_ratio
            for block_id in range(num_blocks):
                block_offset = block_id * block_len_per_layer
                addr = base_addr + block_offset
                # (addr, len, device id)
                blocks_data.append((addr, kv_block_len, self.device_id))
        blocks_data: list[tuple[int, int, int]] = []
        local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank]

        if self.kv_topo.is_kv_layout_blocks_first:
            # Separate and interleave K/V regions to maintain the same
            # descs ordering. This is needed for selecting contiguous heads
            # when split across TP ranks.
        def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool):
            for i, base_addr in enumerate(local_base_addresses):
                # The new block_len uses the prefill block_len,
                # and num_blocks is multiplied by the ratio N.
                kv_block_len = (
                    self.get_backend_aware_kv_block_len(
                        layer_idx=i, first_split=True, mamba_view=mamba
                    )
                    // block_size_ratio
                )
                # Jump one page_size, but the SSM page_size may be bigger when the
                # kernel locks the block size to a specific value.
                block_len_per_layer = (
                    self.block_len_per_layer[i]
                    // block_size_ratio
                    * (1 if not mamba else self._physical_blocks_per_logical_kv_block)
                )
                num_blocks = self._logical_num_blocks if mamba else self.num_blocks
                num_blocks = num_blocks * block_size_ratio
                for block_id in range(num_blocks):
                    block_offset = block_id * block_len_per_layer
                    addr = base_addr + block_offset
                    # Register addresses for V cache (K registered first).
                    v_addr = addr + kv_block_len
                    blocks_data.append((v_addr, kv_block_len, self.device_id))
        logger.debug(
            "Created %s blocks for src engine %s and rank %s on device id %s",
            len(blocks_data),
            self.engine_id,
            self.tp_rank,
            self.device_id,
        )
                    # (addr, len, device id)
                    blocks_data.append((addr, kv_block_len, self.device_id))

                if kv_topo.is_kv_layout_blocks_first:
                    second_split = self.get_backend_aware_kv_block_len(
                        layer_idx=i, first_split=False, mamba_view=mamba
                    )
                    # Separate and interleave K/V regions to maintain the same
                    # descs ordering. This is needed for selecting contiguous heads
                    # when split across TP ranks.
                    for block_id in range(num_blocks):
                        block_offset = block_id * block_len_per_layer
                        addr = base_addr + block_offset
                        # Register addresses for V cache (K registered first).
                        v_addr = addr + kv_block_len
                        blocks_data.append((v_addr, second_split, self.device_id))
        logger.debug(
            "Created %s blocks for src engine %s and rank %s on device id %s",
            len(blocks_data),
            self.engine_id,
            self.tp_rank,
            self.device_id,
        )

        register_blocks(blocks_data, mamba=False)
        if self._has_mamba:
            assert self.num_descs == len(blocks_data)
            logger.debug(
                "Registering additional %s local Mamba blocks", len(blocks_data)
            )
            register_blocks(blocks_data, mamba=True)

        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
        # NIXL_INIT_AGENT to be used for preparations of local descs.

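An address-layout sketch for the local descriptor registration above, with hypothetical sizes: K descriptors are appended first, then the interleaved V descriptors one kv_block_len after each block start:

base_addr = 0x1000
block_len_per_layer = 4096            # one full page: K half + V half
kv_block_len = block_len_per_layer // 2
blocks_data = []
for block_id in range(4):
    addr = base_addr + block_id * block_len_per_layer
    blocks_data.append((addr, kv_block_len, 0))                 # K descs
for block_id in range(4):
    addr = base_addr + block_id * block_len_per_layer
    blocks_data.append((addr + kv_block_len, kv_block_len, 0))  # V descs
assert len(blocks_data) == 8  # doubled 'virtual' regions
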
@@ -1708,7 +1849,8 @@ class NixlConnectorWorker:
        # local origin:| 0| 1| 8| 12|
        # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
        assert self.kv_topo is not None
        block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
        kv_topo = self.kv_topo
        block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id)

        if engine_id not in self.dst_num_blocks:
            self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks

@@ -1768,48 +1910,86 @@ class NixlConnectorWorker:
            # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].

            # Register all remote blocks, but only the corresponding kv heads.
            for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
                # Read our whole local region size from remote.
                local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
                remote_kv_block_len = local_block_len // block_size_ratio
                if block_size_ratio > 1:
                    # using remote kv_block_len as transfer unit
                    local_block_len = remote_kv_block_len
            def register_remote_blocks(
                blocks_data: list[tuple[int, int, int]], mamba: bool
            ):
                for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
                    # Read our whole local region size from remote.
                    local_block_len = self.get_backend_aware_kv_block_len(
                        layer_idx=i, first_split=True, mamba_view=mamba
                    )
                    remote_kv_block_len = local_block_len // block_size_ratio
                    if block_size_ratio > 1:
                        # using remote kv_block_len as transfer unit
                        local_block_len = remote_kv_block_len

                if tp_ratio < 0 and not self.use_mla:
                    # Remote tp is bigger: read a chunk of local region from remote
                    local_block_len = local_block_len // (-tp_ratio)
                rank_offset = (
                    self.tp_rank % tp_ratio * remote_kv_block_len
                    if indexes_into_remote
                    else 0
                )
                for block_id in range(nixl_agent_meta.num_blocks):
                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                    # For each block, grab the heads chunk belonging to rank_i
                    # of size remote_nheads // tp_ratio, which corresponds to
                    # self.block_len == remote_block_len//tp_ratio bytes.
                    addr = base_addr + block_offset + rank_offset
                    # (addr, len, device id)
                    blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
                    if tp_ratio < 0 and not self.use_mla:
                        # Remote tp is bigger: read a chunk of local region from remote
                        local_block_len = local_block_len // (-tp_ratio)
                    rank_offset = (
                        self.tp_rank % tp_ratio * remote_kv_block_len
                        if indexes_into_remote
                        else 0
                    )

                if self.kv_topo.is_kv_layout_blocks_first:
                    # With FlashInfer, index V separately to allow head splitting.
                    for block_id in range(nixl_agent_meta.num_blocks):
                        block_offset = block_id * nixl_agent_meta.block_lens[i]
                    # Assume the same num_blocks for mamba and fa
                    num_blocks = (
                        nixl_agent_meta.num_blocks
                        if not mamba
                        else nixl_agent_meta.num_blocks
                        // self._physical_blocks_per_logical_kv_block
                    )
                    page_size = nixl_agent_meta.block_lens[i] * (
                        1 if not mamba else self._physical_blocks_per_logical_kv_block
                    )
                    for block_id in range(num_blocks):
                        block_offset = block_id * page_size
                        # For each block, grab the heads chunk belonging to rank_i
                        # of size remote_nheads // tp_ratio, which corresponds to
                        # self.block_len == remote_block_len//tp_ratio bytes.
                        addr = base_addr + block_offset + rank_offset
                        v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                        # (addr, len, device id)
                        blocks_data.append(
                            (v_addr, local_block_len, nixl_agent_meta.device_id)
                            (addr, local_block_len, nixl_agent_meta.device_id)
                        )

            logger.debug(
                "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
                len(blocks_data),
                engine_id,
                remote_tp_rank,
                self.tp_rank,
            )
                    if kv_topo.is_kv_layout_blocks_first:
                        # With FlashInfer, index V separately to allow head splitting.
                        second_split = self.get_backend_aware_kv_block_len(
                            layer_idx=i, first_split=False, mamba_view=mamba
                        )
                        # Apply the same scaling as local_block_len above for when we
                        # read a chunk of local V from `tp_ratio` separate remote workers.
                        if tp_ratio < 0 and not self.use_mla:
                            second_split = second_split // (-tp_ratio)
                        for block_id in range(num_blocks):
                            block_offset = block_id * page_size
                            addr = base_addr + block_offset + rank_offset
                            # Hop over the first split of remote page: either K or Conv.
                            if mamba:
                                v_addr = addr + nixl_agent_meta.ssm_sizes[0]
                            else:
                                v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                            blocks_data.append(
                                (v_addr, second_split, nixl_agent_meta.device_id)
                            )

            logger.debug(
                "Created %s blocks for dst engine %s"
                " with remote rank %s and local rank %s",
                len(blocks_data),
                engine_id,
                remote_tp_rank,
                self.tp_rank,
            )

            register_remote_blocks(blocks_data, mamba=False)
            if self._has_mamba:
                # Create extra descs for the Mamba "view" of the same KV cache tensors.
                logger.debug(
                    "Registering additional %s remote Mamba blocks", len(blocks_data)
                )
                register_remote_blocks(blocks_data, mamba=True)

            # Register with NIXL.
            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)

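A simplified worked example of the head-splitting offsets above, with made-up sizes (it ignores block_size_ratio, which is 1 here). Prefill TP=1, decode TP=2 gives tp_ratio=2: each decode rank reads half of the remote heads, i.e. half of each remote block:

remote_block_len = 8192                          # bytes of one remote K (or V) block
tp_ratio = 2
local_block_len = remote_block_len // tp_ratio   # 4096 bytes per rank
tp_rank = 1
rank_offset = tp_rank * local_block_len          # rank 1 starts 4096 bytes in
base_addr, block_id = 0x4000, 3
addr = base_addr + block_id * remote_block_len + rank_offset
assert addr == 0x4000 + 28672                    # 3 * 8192 + 4096
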
@@ -1849,6 +2029,9 @@ class NixlConnectorWorker:
        assert block_size_ratio == 1, (
            "HMA does not support different remote block size yet"
        )
        # Mamba additional constraints
        if self._has_mamba:
            assert tp_ratio == 1, "Mamba does not support heterogeneous TP yet"

        kv_cache_layout = (
            self.kv_cache_layout

@@ -2495,6 +2678,7 @@ class NixlConnectorWorker:
        A single flattened array is returned for all groups anyway.
        """
        region_ids = np.arange(self.num_regions)

        # NOTE (NickLucche) With HMA, every kv group has the same number of layers and
        # layers from different groups share the same kv tensor.
        # E.g. block_ids=[[1, 2], [3]] -> blocks [1, 2] need to be read across all regions,

@@ -2505,11 +2689,33 @@ class NixlConnectorWorker:
        if block_size_ratio is not None:
            num_blocks = int(num_blocks * block_size_ratio)

        # Compute the desc ids for each block.
        # Compute desc ids per group using the right stride: FA descs have
        # num_blocks entries per region (kernel granularity), SSM descs have
        # logical_blocks entries per region (no kernel splitting).
        region_ids = region_ids[:, None]
        block_ids = np.concatenate(block_ids)[None, :]
        descs_ids = region_ids * num_blocks + block_ids
        return descs_ids.flatten()
        if not self._has_mamba:
            block_ids = np.concatenate(block_ids)[None, :]
            descs_ids = region_ids * num_blocks + block_ids
            return descs_ids.flatten()
        else:
            # NOTE (NickLucche) SSM and Attention block regions can be exchanged
            # arbitrarily by the manager. Therefore, descs are duplicated for SSM
            # and Attention like so:
            # desc_handle->[descs_fa (all regions) | descs_ssm (all regions)].
            # This is like having two "low-level views" of the same storage.
            # The `num_fa_descs` offset must be computed per-engine since P and D can
            # have different num_blocks (and thus different FA descs counts).
            ratio = self._physical_blocks_per_logical_kv_block
            # SSM may register fewer num_blocks than FA
            logical_blocks = num_blocks // ratio
            num_fa_descs = self.num_regions * num_blocks
            all_descs = []
            for i, group in enumerate(block_ids):
                stride = logical_blocks if self._is_mamba_group[i] else num_blocks
                group_arr = np.asarray(group)[None, :]
                offset = num_fa_descs if self._is_mamba_group[i] else 0
                all_descs.append((region_ids * stride + group_arr + offset).flatten())
            return np.concatenate(all_descs)

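A self-contained numpy replay of the hybrid stride logic above, matching the values from the kernel-mismatch test earlier (ratio=4, 2 regions, 400 kernel blocks -> 100 logical blocks, FA offset 0, SSM offset 800):

import numpy as np

num_regions, num_blocks, ratio = 2, 400, 4
region_ids = np.arange(num_regions)[:, None]
logical_blocks = num_blocks // ratio             # 100
num_fa_descs = num_regions * num_blocks          # 800
fa = (region_ids * num_blocks + np.array([[3, 7]])).flatten()   # [3, 7, 403, 407]
ssm = (region_ids * logical_blocks + np.array([[1, 2]]) + num_fa_descs).flatten()
assert list(np.concatenate([fa, ssm])) == [3, 7, 403, 407, 801, 802, 901, 902]
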
    def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
        """

@@ -2523,16 +2729,22 @@ class NixlConnectorWorker:
        block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
            1, -1
        )
        # Mamba blocks have no logical<>physical discrepancy
        group_specs = self.kv_cache_config.kv_cache_groups
        return [
            BlockTable.map_to_kernel_blocks(
                np.array(group),
                self._physical_blocks_per_logical_kv_block,
                block_arange,
            ).tolist()
            for group in block_ids
            if not isinstance(group_specs[i].kv_cache_spec, MambaSpec)
            else group
            for i, group in enumerate(block_ids)
        ]

    def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
    def get_backend_aware_kv_block_len(
        self, layer_idx: int, first_split: bool = True, mamba_view: bool = False
    ) -> int:
        """
        Get the block length for one K/V element (K and V have the same size).

@@ -2540,11 +2752,38 @@ class NixlConnectorWorker:
        block, as K and V are in separate regions.
        For FlashInfer, this is half the length of the whole block, as K and V
        share the same region.
        Similarly, for SSM-based models, state and conv are interleaved, but
        crucially their sizes differ.
        Reference diagram:
                            KVCacheTensor (Shared)
                           /                      \
                          /                        \
            Attention (FlashInfer) View          Mamba View
                      |                              |
                      |                              |
            +-------------------+         +--------------------+
            |   KVCacheTensor   |         |   KVCacheTensor    |
            |                   |         |                    |
            |<----- page ------>|         |<------ page ------>|
            |       size        |         |        size        |
            |  Key 0  |  Val 0  |         | Conv 0  |  SSM 0   |
            |  Key 1  |  Val 1  |         | Conv 1  |  SSM 1   |
            |   ...   |   ...   |         |  ...    |   ...    |
            | Key N-2 | Val N-2 |         | Conv N-2| SSM N-2  |
            | Key N-1 | Val N-1 |         | Conv N-1| SSM N-1  |
            +-------------------+         +--------------------+
            |1st_split-2nd_split|         |1st_split-2nd_split |
        """
        assert self.kv_topo is not None
        if self.kv_topo.is_kv_layout_blocks_first:
            # For indexing only half (either just the K or V part).
            block_len = self.block_len_per_layer[layer_idx] // 2
            if mamba_view:
                # NOTE (NickLucche) Mamba opt: this already skips the padding, so
                # we're only transferring the minimum required bytes.
                block_len = self._mamba_ssm_size[not first_split]
            else:
                block_len = self.block_len_per_layer[layer_idx] // 2
        else:
            block_len = self.block_len_per_layer[layer_idx]
        return block_len

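Hypothetical numbers showing how the two splits above differ between the attention view (equal K/V halves) and the Mamba view (unequal Conv/SSM sizes):

block_len_per_layer = [8192]     # one blocks_first page: K half + V half
mamba_ssm_size = (1024, 3072)    # (conv_bytes, ssm_bytes), deliberately unequal
# Attention view: first and second splits are both half a page.
k_len = v_len = block_len_per_layer[0] // 2           # 4096 each
# Mamba view: first_split is Conv, second is SSM (indexed via `not first_split`).
conv_len = mamba_ssm_size[0]     # first_split=True  -> index 0
ssm_len = mamba_ssm_size[1]      # first_split=False -> index 1
assert (k_len, conv_len, ssm_len) == (4096, 1024, 3072)
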