[MODEL] Add support for Zamba2 models (#13185)

Signed-off-by: Yury Tokpanov <yury@zyphra.com>
Signed-off-by: Quentin Anthony <qganthony@yahoo.com>
Co-authored-by: Quentin Anthony <qganthony@yahoo.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
yury-tokpanov
2025-03-18 08:56:21 -07:00
committed by GitHub
parent 8b793f7ec6
commit 452e8fd968
9 changed files with 1081 additions and 26 deletions

View File

@@ -9,7 +9,7 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal
# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev"]
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
@@ -27,17 +27,19 @@ def test_models(
) -> None:
# numeric error produces different generation
if 'Bamba' in model:
if "Bamba" in model:
example_prompts.pop(3)
with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# numeric error during prefill chucking produces different generation
# numeric error during prefill chunking produces different generation
# compared to w/o prefill chunking for those examples, removed them for now
if 'Jamba' in model:
if "Jamba" in model:
example_prompts.pop(7)
example_prompts.pop(2)
example_prompts.pop(1)
elif 'Bamba' in model:
elif "Bamba" in model:
example_prompts.pop(6)
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba
elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,