[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation (#37310)

This commit is contained in:
zhanqiuhu
2026-03-19 03:22:00 -04:00
committed by GitHub
parent b21d384304
commit d49f273144
5 changed files with 263 additions and 13 deletions

View File

@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
def make_kv_cache_config(
block_size: int,
hma_enabled: bool = False,
swa_enabled: bool = False,
mamba_enabled: bool = False,
sw_size: int = 128,
num_blocks: int = 100,
) -> KVCacheConfig:
@@ -438,7 +440,7 @@ def make_kv_cache_config(
),
)
]
if hma_enabled:
if swa_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["layer1", "layer3"],
@@ -451,6 +453,32 @@ def make_kv_cache_config(
),
)
)
if mamba_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["mamba0", "mamba1"],
MambaSpec(
block_size=block_size,
shapes=((16,), (16,)),
dtypes=(torch.float16,),
),
)
)
return KVCacheConfig(
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
)
def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
    """Build a bare-bones NixlConnectorScheduler for unit tests.

    The real ``__init__`` is intentionally bypassed (it wires up transfer
    machinery the tests don't need); instead the instance is allocated raw
    and only the two flags consulted by the N-1 prefill logic are populated.
    """
    # Imported lazily so merely importing this module does not pull in the
    # full NIXL connector stack.
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    scheduler = object.__new__(NixlConnectorScheduler)
    scheduler._has_mamba = has_mamba
    scheduler._is_hma_required = is_hma_required
    return scheduler