[Kernel] Mamba support different layout for Conv state (#37416)

This commit is contained in:
Nicolò Lucchesi
2026-04-03 01:50:09 +02:00
committed by GitHub
parent bb39382b2b
commit 66e86f1dbd
11 changed files with 169 additions and 39 deletions

View File

@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
def _set_conv_state_layout(monkeypatch, layout: str) -> None:
    """Select the SSM conv-state layout for the current test.

    Sets the ``VLLM_SSM_CONV_STATE_LAYOUT`` env var via *monkeypatch* (so it
    is undone at test teardown) and invalidates the cached getter so the new
    value is actually observed.
    """
    from vllm.model_executor.layers.mamba import mamba_utils as _mamba_utils

    monkeypatch.setenv("VLLM_SSM_CONV_STATE_LAYOUT", layout)
    # get_conv_state_layout memoizes its result; drop the stale entry so the
    # freshly-set env var takes effect for this test.
    _mamba_utils.get_conv_state_layout.cache_clear()
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@@ -102,12 +110,15 @@ def test_models(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_batching(
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
conv_state_layout: str,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
@@ -116,6 +127,8 @@ def test_batching(
except ValueError:
pass
_set_conv_state_layout(monkeypatch, conv_state_layout)
for_loop_outputs = []
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for prompt in example_prompts:
@@ -138,11 +151,14 @@ def test_batching(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_chunked_prefill_with_parallel_sampling(
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
conv_state_layout: str,
) -> None:
"""
Tests chunked prefill in conjunction with n > 1.
@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
decoding steps inside a chunked prefill forward pass
(where we have both prefill and decode together)
"""
_set_conv_state_layout(monkeypatch, conv_state_layout)
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
with vllm_runner(
model,
@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
conv_state_layout: str,
) -> None:
"""
This test is for verifying that mamba cache is padded to CG captured
batch size. If it's not, a torch RuntimeError will be raised because
tensor dimensions aren't compatible.
"""
_set_conv_state_layout(monkeypatch, conv_state_layout)
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys(