[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -61,14 +61,6 @@ V1_SUPPORTED_MODELS = [
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
]
|
||||
|
||||
ATTN_BLOCK_SIZES = {
|
||||
"ibm-ai-platform/Bamba-9B-v1": 528,
|
||||
"Zyphra/Zamba2-1.2B-instruct": 80,
|
||||
"nvidia/Nemotron-H-8B-Base-8K": 528,
|
||||
"ibm-granite/granite-4.0-tiny-preview": 400,
|
||||
"tiiuae/Falcon-H1-0.5B-Base": 800,
|
||||
}
|
||||
|
||||
# Avoid OOM
|
||||
MAX_NUM_SEQS = 4
|
||||
|
||||
@@ -105,11 +97,6 @@ def test_models(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
|
||||
block_size = ATTN_BLOCK_SIZES[model]
|
||||
else:
|
||||
block_size = 16
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
if model in HYBRID_MODELS:
|
||||
@@ -118,8 +105,7 @@ def test_models(
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=False,
|
||||
block_size=block_size) as vllm_model:
|
||||
enable_prefix_caching=False) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user