[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation (#37310)
This commit is contained in:
@@ -2007,7 +2007,7 @@ def test_transfer_failure_logging(
|
||||
connector = NixlConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
|
||||
make_kv_cache_config(block_size=16, swa_enabled=enable_hma),
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
|
||||
"""Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -14,24 +14,26 @@ from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
create_request,
|
||||
create_vllm_config,
|
||||
make_kv_cache_config,
|
||||
make_nixl_scheduler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"hma_enabled,expected_sw_sizes",
|
||||
"swa_enabled,expected_sw_sizes",
|
||||
[
|
||||
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
|
||||
# SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
|
||||
(True, [0, 128 + 1]),
|
||||
# HMA disabled: only FullAttentionSpec (0)
|
||||
# SWA disabled: only FullAttentionSpec (0)
|
||||
(False, [0]),
|
||||
],
|
||||
)
|
||||
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
|
||||
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
|
||||
"""Test sw_sizes is correctly computed based on HMA enabled/disabled."""
|
||||
def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes):
|
||||
"""Test sw_sizes is correctly computed based on SWA enabled/disabled."""
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlConnectorScheduler,
|
||||
)
|
||||
@@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
|
||||
vllm_config = create_vllm_config(block_size=block_size)
|
||||
# SW 2048 tokens=>128 blocks
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=block_size, hma_enabled=hma_enabled, sw_size=2048
|
||||
block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
|
||||
)
|
||||
|
||||
scheduler = NixlConnectorScheduler(
|
||||
@@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma():
|
||||
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
|
||||
worker._physical_blocks_per_logical_kv_block = 2
|
||||
# FA + SW groups (neither is MambaSpec, so both get expanded)
|
||||
worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True)
|
||||
worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True)
|
||||
|
||||
# Test conversion: FA + SW group
|
||||
logical_block_ids = [[0, 1, 2], [3, 4]]
|
||||
@@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids():
|
||||
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
|
||||
assert list(req_meta.remote.block_ids[1]) == [20, 21]
|
||||
assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])
|
||||
|
||||
|
||||
# ── Mamba N-1 prefill tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"has_mamba,is_hma_required,expected_count",
|
||||
[
|
||||
(True, True, 9),
|
||||
(False, False, 10),
|
||||
(False, True, 10),
|
||||
],
|
||||
ids=["mamba", "fa_only", "swa_only"],
|
||||
)
|
||||
def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count):
|
||||
"""D-side: Mamba gets N-1 matched tokens, non-Mamba gets N."""
|
||||
sched = make_nixl_scheduler(has_mamba=has_mamba, is_hma_required=is_hma_required)
|
||||
req = create_request(num_tokens=10, do_remote_prefill=True)
|
||||
|
||||
count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
|
||||
assert count == expected_count
|
||||
assert is_async is True
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
def test_mamba_n1_p_side_truncation():
|
||||
"""P-side: Mamba truncates prompt to N-1, sets max_tokens=1.
|
||||
|
||||
Also verifies idempotency (calling again is a no-op) which is
|
||||
needed for preemption safety via the _p_side_truncated guard,
|
||||
and that non-Mamba models skip truncation entirely.
|
||||
"""
|
||||
sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True)
|
||||
req = create_request(num_tokens=10, do_remote_decode=True)
|
||||
req.max_tokens = 128
|
||||
original_len = len(req.prompt_token_ids)
|
||||
|
||||
count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
|
||||
|
||||
assert count == 0
|
||||
assert is_async is False
|
||||
assert len(req.prompt_token_ids) == original_len - 1
|
||||
assert req.num_prompt_tokens == original_len - 1
|
||||
assert req.max_tokens == 1
|
||||
assert req.kv_transfer_params["_p_side_truncated"] is True
|
||||
|
||||
# Idempotency: second call must not truncate further
|
||||
sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
|
||||
assert len(req.prompt_token_ids) == original_len - 1
|
||||
|
||||
# Non-Mamba: truncation is skipped
|
||||
fa_sched = make_nixl_scheduler(has_mamba=False, is_hma_required=False)
|
||||
fa_req = create_request(num_tokens=10, do_remote_decode=True)
|
||||
fa_original = len(fa_req.prompt_token_ids)
|
||||
|
||||
fa_sched.get_num_new_matched_tokens(fa_req, num_computed_tokens=0)
|
||||
assert len(fa_req.prompt_token_ids) == fa_original
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma",
|
||||
[
|
||||
(True, True, True, True),
|
||||
(True, False, False, True),
|
||||
(False, False, False, False),
|
||||
],
|
||||
ids=["fa_swa_mamba", "fa_swa_only", "fa_only"],
|
||||
)
|
||||
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
|
||||
def test_has_mamba_init(
|
||||
mock_platform,
|
||||
swa_enabled,
|
||||
mamba_enabled,
|
||||
expected_has_mamba,
|
||||
expected_is_hma,
|
||||
):
|
||||
"""Test _has_mamba / _is_hma_required derived from kv_cache_groups."""
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlConnectorScheduler,
|
||||
)
|
||||
|
||||
mock_platform.device_type = "cpu"
|
||||
|
||||
block_size = 16
|
||||
vllm_config = create_vllm_config(block_size=block_size)
|
||||
# VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config
|
||||
# is set; override so we can test the scheduler's own derivation.
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=block_size,
|
||||
swa_enabled=swa_enabled,
|
||||
mamba_enabled=mamba_enabled,
|
||||
)
|
||||
|
||||
scheduler = NixlConnectorScheduler(
|
||||
vllm_config=vllm_config,
|
||||
engine_id="test-engine",
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
assert scheduler._has_mamba is expected_has_mamba
|
||||
assert scheduler._is_hma_required is expected_is_hma
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
@@ -13,6 +18,7 @@ from .utils import (
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
make_kv_cache_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
@@ -579,3 +585,73 @@ def test_cannot_recv():
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
_ = scheduler.schedule()
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
|
||||
def test_p_side_chunked_prefill_mamba(mock_platform):
|
||||
"""P-side integration: Mamba N-1 truncation + chunked prefill completes.
|
||||
|
||||
A 64-token P-side request is truncated to 63 by the N-1 fix, then
|
||||
chunked into two prefill steps (32 + 31) and finishes with
|
||||
LENGTH_CAPPED because max_tokens is set to 1.
|
||||
"""
|
||||
mock_platform.device_type = "cpu"
|
||||
|
||||
BATCH_SIZE = 32
|
||||
NUM_TOKENS = 64
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
max_num_batched_tokens=BATCH_SIZE,
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=BLOCK_SIZE,
|
||||
mamba_enabled=True,
|
||||
num_blocks=10000,
|
||||
)
|
||||
|
||||
scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config)
|
||||
|
||||
request = create_request(
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True,
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
request.max_tokens = 128
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# ── Step 1: first chunk ──
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
assert len(request.prompt_token_ids) == NUM_TOKENS - 1
|
||||
assert request.max_tokens == 1
|
||||
assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE
|
||||
assert request.num_computed_tokens == BATCH_SIZE
|
||||
|
||||
# Model returns no tokens for intermediate prefill chunk
|
||||
intermediate_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id],
|
||||
req_id_to_index={request.request_id: 0},
|
||||
sampled_token_ids=[[]],
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, intermediate_output)
|
||||
|
||||
# ── Step 2: remaining chunk ──
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
remaining = NUM_TOKENS - 1 - BATCH_SIZE # 31
|
||||
assert scheduler_output.num_scheduled_tokens[request_id] == remaining
|
||||
assert request.num_computed_tokens == NUM_TOKENS - 1
|
||||
|
||||
# Prefill complete: model generates 1 decode token
|
||||
final_output = create_model_runner_output([request])
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output)
|
||||
|
||||
# max_tokens=1 → request finishes with LENGTH
|
||||
outputs = engine_core_outputs[0].outputs
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0].finish_reason == FinishReason.LENGTH
|
||||
|
||||
@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
MambaSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
|
||||
|
||||
def make_kv_cache_config(
|
||||
block_size: int,
|
||||
hma_enabled: bool = False,
|
||||
swa_enabled: bool = False,
|
||||
mamba_enabled: bool = False,
|
||||
sw_size: int = 128,
|
||||
num_blocks: int = 100,
|
||||
) -> KVCacheConfig:
|
||||
@@ -438,7 +440,7 @@ def make_kv_cache_config(
|
||||
),
|
||||
)
|
||||
]
|
||||
if hma_enabled:
|
||||
if swa_enabled:
|
||||
kv_cache_groups.append(
|
||||
KVCacheGroupSpec(
|
||||
["layer1", "layer3"],
|
||||
@@ -451,6 +453,32 @@ def make_kv_cache_config(
|
||||
),
|
||||
)
|
||||
)
|
||||
if mamba_enabled:
|
||||
kv_cache_groups.append(
|
||||
KVCacheGroupSpec(
|
||||
["mamba0", "mamba1"],
|
||||
MambaSpec(
|
||||
block_size=block_size,
|
||||
shapes=((16,), (16,)),
|
||||
dtypes=(torch.float16,),
|
||||
),
|
||||
)
|
||||
)
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
|
||||
)
|
||||
|
||||
|
||||
def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
|
||||
"""Create a NixlConnectorScheduler via __new__ (skipping __init__).
|
||||
|
||||
Only sets the two flags needed by the N-1 prefill logic.
|
||||
"""
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlConnectorScheduler,
|
||||
)
|
||||
|
||||
sched = object.__new__(NixlConnectorScheduler)
|
||||
sched._has_mamba = has_mamba
|
||||
sched._is_hma_required = is_hma_required
|
||||
return sched
|
||||
|
||||
@@ -572,6 +572,10 @@ class NixlConnectorScheduler:
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
)
|
||||
self._has_mamba = any(
|
||||
isinstance(g.kv_cache_spec, MambaSpec)
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
|
||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
|
||||
@@ -717,6 +721,39 @@ class NixlConnectorScheduler:
|
||||
logger.warning("Connection listener got unexpected message %s", msg)
|
||||
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
|
||||
|
||||
def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
|
||||
"""D-side only. Returns N-1 for Mamba models since the decoder
|
||||
always recomputes the last token and must start from h(N-1)."""
|
||||
if self._has_mamba and num_prompt_tokens > 1:
|
||||
return num_prompt_tokens - 1
|
||||
return num_prompt_tokens
|
||||
|
||||
def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
|
||||
"""P-side only: drop the last prompt token so the prefiller computes
|
||||
h(N-1) instead of h(N). The decoder recomputes the last token to
|
||||
derive h(N) correctly.
|
||||
|
||||
Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
|
||||
request is preempted and rescheduled."""
|
||||
params = request.kv_transfer_params
|
||||
if (
|
||||
params is not None
|
||||
# Guard against repeated truncation after preemption/reschedule.
|
||||
and not params.get("_p_side_truncated")
|
||||
and request.num_prompt_tokens > 1
|
||||
):
|
||||
if request.prompt_token_ids is not None:
|
||||
request.prompt_token_ids.pop()
|
||||
elif request.prompt_embeds is not None:
|
||||
request.prompt_embeds = request.prompt_embeds[:-1]
|
||||
else:
|
||||
return
|
||||
|
||||
request._all_token_ids.pop()
|
||||
request.num_prompt_tokens -= 1
|
||||
request.max_tokens = 1
|
||||
params["_p_side_truncated"] = True
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
@@ -746,10 +783,14 @@ class NixlConnectorScheduler:
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
token_ids = request.prompt_token_ids or []
|
||||
count = len(token_ids) - num_computed_tokens
|
||||
actual = self._mamba_prefill_token_count(len(token_ids))
|
||||
count = actual - num_computed_tokens
|
||||
if count > 0:
|
||||
return count, True
|
||||
|
||||
if params is not None and params.get("do_remote_decode") and self._has_mamba:
|
||||
self._truncate_mamba_request_for_prefill(request)
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user