[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation (#37310)
This commit is contained in:
@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
MambaSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
|
||||
|
||||
def make_kv_cache_config(
|
||||
block_size: int,
|
||||
hma_enabled: bool = False,
|
||||
swa_enabled: bool = False,
|
||||
mamba_enabled: bool = False,
|
||||
sw_size: int = 128,
|
||||
num_blocks: int = 100,
|
||||
) -> KVCacheConfig:
|
||||
@@ -438,7 +440,7 @@ def make_kv_cache_config(
|
||||
),
|
||||
)
|
||||
]
|
||||
if hma_enabled:
|
||||
if swa_enabled:
|
||||
kv_cache_groups.append(
|
||||
KVCacheGroupSpec(
|
||||
["layer1", "layer3"],
|
||||
@@ -451,6 +453,32 @@ def make_kv_cache_config(
|
||||
),
|
||||
)
|
||||
)
|
||||
if mamba_enabled:
|
||||
kv_cache_groups.append(
|
||||
KVCacheGroupSpec(
|
||||
["mamba0", "mamba1"],
|
||||
MambaSpec(
|
||||
block_size=block_size,
|
||||
shapes=((16,), (16,)),
|
||||
dtypes=(torch.float16,),
|
||||
),
|
||||
)
|
||||
)
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
|
||||
)
|
||||
|
||||
|
||||
def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
    """Build a bare NixlConnectorScheduler without running its __init__.

    The instance is allocated with ``object.__new__`` so the constructor is
    bypassed entirely; only the two flags consulted by the N-1 prefill logic
    (``_has_mamba`` and ``_is_hma_required``) are populated. No other
    attribute is set on the returned object.
    """
    # Imported locally, as in the original helper, so the connector module is
    # only loaded when a scheduler is actually requested.
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    scheduler = object.__new__(NixlConnectorScheduler)
    scheduler._has_mamba = has_mamba
    scheduler._is_hma_required = is_hma_required
    return scheduler
|
||||
|
||||
Reference in New Issue
Block a user