[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-07-15 13:04:35 +02:00
committed by GitHub
parent c586b55667
commit 3534c39a20
14 changed files with 441 additions and 353 deletions

View File

@@ -61,14 +61,6 @@ V1_SUPPORTED_MODELS = [
"tiiuae/Falcon-H1-0.5B-Base",
]
ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
"nvidia/Nemotron-H-8B-Base-8K": 528,
"ibm-granite/granite-4.0-tiny-preview": 400,
"tiiuae/Falcon-H1-0.5B-Base": 800,
}
# Avoid OOM
MAX_NUM_SEQS = 4
@@ -105,11 +97,6 @@ def test_models(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
block_size = ATTN_BLOCK_SIZES[model]
else:
block_size = 16
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
@@ -118,8 +105,7 @@ def test_models(
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else: