[Kernel] Mamba support different layout for Conv state (#37416)
This commit is contained in:
@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
|
||||
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
|
||||
|
||||
|
||||
def _set_conv_state_layout(monkeypatch, layout: str) -> None:
    """Select *layout* as the SSM conv-state layout for the current test.

    Exports VLLM_SSM_CONV_STATE_LAYOUT via pytest's monkeypatch (so the
    change is undone at test teardown) and invalidates the memoized
    lookup so the new value is actually read on the next call.
    """
    from vllm.model_executor.layers.mamba import mamba_utils as _mu

    monkeypatch.setenv("VLLM_SSM_CONV_STATE_LAYOUT", layout)
    # get_conv_state_layout is lru-cached; without this the stale layout
    # from a previous test would be returned.
    _mu.get_conv_state_layout.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@@ -102,12 +110,15 @@ def test_models(
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
|
||||
def test_batching(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
conv_state_layout: str,
|
||||
) -> None:
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
@@ -116,6 +127,8 @@ def test_batching(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
_set_conv_state_layout(monkeypatch, conv_state_layout)
|
||||
|
||||
for_loop_outputs = []
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
for prompt in example_prompts:
|
||||
@@ -138,11 +151,14 @@ def test_batching(
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
|
||||
def test_chunked_prefill_with_parallel_sampling(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
conv_state_layout: str,
|
||||
) -> None:
|
||||
"""
|
||||
Tests chunked prefill in conjunction with n > 1.
|
||||
@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
|
||||
decoding steps inside a chunked prefill forward pass
|
||||
(where we have both prefill and decode together)
|
||||
"""
|
||||
_set_conv_state_layout(monkeypatch, conv_state_layout)
|
||||
|
||||
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
|
||||
with vllm_runner(
|
||||
model,
|
||||
@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@pytest.mark.parametrize("max_tokens", [20])
|
||||
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
|
||||
def test_mamba_cache_cg_padding(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
conv_state_layout: str,
|
||||
) -> None:
|
||||
"""
|
||||
This test is for verifying that mamba cache is padded to CG captured
|
||||
batch size. If it's not, a torch RuntimeError will be raised because
|
||||
tensor dimensions aren't compatible.
|
||||
"""
|
||||
_set_conv_state_layout(monkeypatch, conv_state_layout)
|
||||
|
||||
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
|
||||
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
|
||||
cudagraph_dispatcher.initialize_cudagraph_keys(
|
||||
|
||||
Reference in New Issue
Block a user