[MODEL] Add support for Zamba2 models (#13185)
Signed-off-by: Yury Tokpanov <yury@zyphra.com> Signed-off-by: Quentin Anthony <qganthony@yahoo.com> Co-authored-by: Quentin Anthony <qganthony@yahoo.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm.sampling_params import SamplingParams
|
||||
from ...utils import check_outputs_equal
|
||||
|
||||
# This test is for the hybrid models
|
||||
MODELS = ["ai21labs/Jamba-tiny-dev"]
|
||||
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
|
||||
# Bamba at Fp32 is too big for the CI (L4 GPU).
|
||||
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
|
||||
|
||||
@@ -27,17 +27,19 @@ def test_models(
|
||||
) -> None:
|
||||
|
||||
# numeric error produces different generation
|
||||
if 'Bamba' in model:
|
||||
if "Bamba" in model:
|
||||
example_prompts.pop(3)
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
model_kwargs={
|
||||
"use_mamba_kernels":
|
||||
False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}) as hf_model:
|
||||
model_kwargs = {
|
||||
"use_mamba_kernels": False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}
|
||||
if "Zamba2" in model:
|
||||
# Zamba2 HF implementation automatically checks if mamba kernels are
|
||||
# installed
|
||||
model_kwargs = {}
|
||||
|
||||
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
|
||||
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
|
||||
model: str, dtype: str,
|
||||
max_tokens: int) -> None:
|
||||
# numeric error during prefill chucking produces different generation
|
||||
# numeric error during prefill chunking produces different generation
|
||||
# compared to w/o prefill chunking for those examples, removed them for now
|
||||
if 'Jamba' in model:
|
||||
if "Jamba" in model:
|
||||
example_prompts.pop(7)
|
||||
example_prompts.pop(2)
|
||||
example_prompts.pop(1)
|
||||
elif 'Bamba' in model:
|
||||
elif "Bamba" in model:
|
||||
example_prompts.pop(6)
|
||||
example_prompts.pop(3)
|
||||
example_prompts.pop(2)
|
||||
dtype = "half" # use a different dtype for Bamba
|
||||
elif "Zamba2" in model:
|
||||
example_prompts.pop(7)
|
||||
dtype = "half"
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
model_kwargs={
|
||||
"use_mamba_kernels":
|
||||
False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}) as hf_model:
|
||||
model_kwargs = {
|
||||
"use_mamba_kernels": False, # mamba kernels are not installed so HF
|
||||
# don't use them
|
||||
}
|
||||
if "Zamba2" in model:
|
||||
# Zamba2 HF implementation automatically checks if mamba kernels are
|
||||
# installed
|
||||
model_kwargs = {}
|
||||
|
||||
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
|
||||
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model,
|
||||
|
||||
Reference in New Issue
Block a user