[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation (#37310)

This commit is contained in:
zhanqiuhu
2026-03-19 03:22:00 -04:00
committed by GitHub
parent b21d384304
commit d49f273144
5 changed files with 263 additions and 13 deletions

View File

@@ -2007,7 +2007,7 @@ def test_transfer_failure_logging(
connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
make_kv_cache_config(block_size=16, swa_enabled=enable_hma),
)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
"""Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill."""
from unittest.mock import patch
@@ -14,24 +14,26 @@ from vllm.v1.core.single_type_kv_cache_manager import (
)
from .utils import (
create_request,
create_vllm_config,
make_kv_cache_config,
make_nixl_scheduler,
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"hma_enabled,expected_sw_sizes",
"swa_enabled,expected_sw_sizes",
[
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
# SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(True, [0, 128 + 1]),
# HMA disabled: only FullAttentionSpec (0)
# SWA disabled: only FullAttentionSpec (0)
(False, [0]),
],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on SWA enabled/disabled."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
@@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
vllm_config = create_vllm_config(block_size=block_size)
# SW 2048 tokens=>128 blocks
kv_cache_config = make_kv_cache_config(
block_size=block_size, hma_enabled=hma_enabled, sw_size=2048
block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
)
scheduler = NixlConnectorScheduler(
@@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma():
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker._physical_blocks_per_logical_kv_block = 2
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True)
worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True)
# Test conversion: FA + SW group
logical_block_ids = [[0, 1, 2], [3, 4]]
@@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids():
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [20, 21]
assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])
# ── Mamba N-1 prefill tests ──────────────────────────────────────────────
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"has_mamba,is_hma_required,expected_count",
[
(True, True, 9),
(False, False, 10),
(False, True, 10),
],
ids=["mamba", "fa_only", "swa_only"],
)
def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count):
"""D-side: Mamba gets N-1 matched tokens, non-Mamba gets N."""
sched = make_nixl_scheduler(has_mamba=has_mamba, is_hma_required=is_hma_required)
req = create_request(num_tokens=10, do_remote_prefill=True)
count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
assert count == expected_count
assert is_async is True
@pytest.mark.cpu_test
def test_mamba_n1_p_side_truncation():
"""P-side: Mamba truncates prompt to N-1, sets max_tokens=1.
Also verifies idempotency (calling again is a no-op) which is
needed for preemption safety via the _p_side_truncated guard,
and that non-Mamba models skip truncation entirely.
"""
sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True)
req = create_request(num_tokens=10, do_remote_decode=True)
req.max_tokens = 128
original_len = len(req.prompt_token_ids)
count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
assert count == 0
assert is_async is False
assert len(req.prompt_token_ids) == original_len - 1
assert req.num_prompt_tokens == original_len - 1
assert req.max_tokens == 1
assert req.kv_transfer_params["_p_side_truncated"] is True
# Idempotency: second call must not truncate further
sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
assert len(req.prompt_token_ids) == original_len - 1
# Non-Mamba: truncation is skipped
fa_sched = make_nixl_scheduler(has_mamba=False, is_hma_required=False)
fa_req = create_request(num_tokens=10, do_remote_decode=True)
fa_original = len(fa_req.prompt_token_ids)
fa_sched.get_num_new_matched_tokens(fa_req, num_computed_tokens=0)
assert len(fa_req.prompt_token_ids) == fa_original
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma",
[
(True, True, True, True),
(True, False, False, True),
(False, False, False, False),
],
ids=["fa_swa_mamba", "fa_swa_only", "fa_only"],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_has_mamba_init(
mock_platform,
swa_enabled,
mamba_enabled,
expected_has_mamba,
expected_is_hma,
):
"""Test _has_mamba / _is_hma_required derived from kv_cache_groups."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
mock_platform.device_type = "cpu"
block_size = 16
vllm_config = create_vllm_config(block_size=block_size)
# VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config
# is set; override so we can test the scheduler's own derivation.
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size,
swa_enabled=swa_enabled,
mamba_enabled=mamba_enabled,
)
scheduler = NixlConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler._has_mamba is expected_has_mamba
assert scheduler._is_hma_required is expected_is_hma

View File

@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from unittest.mock import patch
import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (
@@ -13,6 +18,7 @@ from .utils import (
create_request,
create_scheduler,
create_vllm_config,
make_kv_cache_config,
)
pytestmark = pytest.mark.cpu_test
@@ -579,3 +585,73 @@ def test_cannot_recv():
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_p_side_chunked_prefill_mamba(mock_platform):
"""P-side integration: Mamba N-1 truncation + chunked prefill completes.
A 64-token P-side request is truncated to 63 by the N-1 fix, then
chunked into two prefill steps (32 + 31) and finishes with
LENGTH_CAPPED because max_tokens is set to 1.
"""
mock_platform.device_type = "cpu"
BATCH_SIZE = 32
NUM_TOKENS = 64
BLOCK_SIZE = 16
vllm_config = create_vllm_config(
max_num_batched_tokens=BATCH_SIZE,
block_size=BLOCK_SIZE,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=BLOCK_SIZE,
mamba_enabled=True,
num_blocks=10000,
)
scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config)
request = create_request(
num_tokens=NUM_TOKENS,
do_remote_decode=True,
block_size=BLOCK_SIZE,
)
request.max_tokens = 128
scheduler.add_request(request)
request_id = request.request_id
# ── Step 1: first chunk ──
scheduler_output = scheduler.schedule()
assert len(request.prompt_token_ids) == NUM_TOKENS - 1
assert request.max_tokens == 1
assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE
assert request.num_computed_tokens == BATCH_SIZE
# Model returns no tokens for intermediate prefill chunk
intermediate_output = ModelRunnerOutput(
req_ids=[request.request_id],
req_id_to_index={request.request_id: 0},
sampled_token_ids=[[]],
)
scheduler.update_from_output(scheduler_output, intermediate_output)
# ── Step 2: remaining chunk ──
scheduler_output = scheduler.schedule()
remaining = NUM_TOKENS - 1 - BATCH_SIZE # 31
assert scheduler_output.num_scheduled_tokens[request_id] == remaining
assert request.num_computed_tokens == NUM_TOKENS - 1
# Prefill complete: model generates 1 decode token
final_output = create_model_runner_output([request])
engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output)
# max_tokens=1 → request finishes with LENGTH
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
assert outputs[0].finish_reason == FinishReason.LENGTH

View File

@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
def make_kv_cache_config(
block_size: int,
hma_enabled: bool = False,
swa_enabled: bool = False,
mamba_enabled: bool = False,
sw_size: int = 128,
num_blocks: int = 100,
) -> KVCacheConfig:
@@ -438,7 +440,7 @@ def make_kv_cache_config(
),
)
]
if hma_enabled:
if swa_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["layer1", "layer3"],
@@ -451,6 +453,32 @@ def make_kv_cache_config(
),
)
)
if mamba_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["mamba0", "mamba1"],
MambaSpec(
block_size=block_size,
shapes=((16,), (16,)),
dtypes=(torch.float16,),
),
)
)
return KVCacheConfig(
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
)
def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
"""Create a NixlConnectorScheduler via __new__ (skipping __init__).
Only sets the two flags needed by the N-1 prefill logic.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
sched = object.__new__(NixlConnectorScheduler)
sched._has_mamba = has_mamba
sched._is_hma_required = is_hma_required
return sched

View File

@@ -572,6 +572,10 @@ class NixlConnectorScheduler:
for g in kv_cache_config.kv_cache_groups
)
)
self._has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec)
for g in kv_cache_config.kv_cache_groups
)
logger.info("Initializing NIXL Scheduler %s", engine_id)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
@@ -717,6 +721,39 @@ class NixlConnectorScheduler:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
"""D-side only. Returns N-1 for Mamba models since the decoder
always recomputes the last token and must start from h(N-1)."""
if self._has_mamba and num_prompt_tokens > 1:
return num_prompt_tokens - 1
return num_prompt_tokens
def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
"""P-side only: drop the last prompt token so the prefiller computes
h(N-1) instead of h(N). The decoder recomputes the last token to
derive h(N) correctly.
Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
request is preempted and rescheduled."""
params = request.kv_transfer_params
if (
params is not None
# Guard against repeated truncation after preemption/reschedule.
and not params.get("_p_side_truncated")
and request.num_prompt_tokens > 1
):
if request.prompt_token_ids is not None:
request.prompt_token_ids.pop()
elif request.prompt_embeds is not None:
request.prompt_embeds = request.prompt_embeds[:-1]
else:
return
request._all_token_ids.pop()
request.num_prompt_tokens -= 1
request.max_tokens = 1
params["_p_side_truncated"] = True
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
@@ -746,10 +783,14 @@ class NixlConnectorScheduler:
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
actual = self._mamba_prefill_token_count(len(token_ids))
count = actual - num_computed_tokens
if count > 0:
return count, True
if params is not None and params.get("do_remote_decode") and self._has_mamba:
self._truncate_mamba_request_for_prefill(request)
# No remote prefill for this request.
return 0, False