[Model] Add PLaMo2 (#14323)

Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Authored by Shinichi Hemmi on 2025-04-16 11:31:30 +09:00, committed by GitHub
parent fdcb850f14
commit 3badb0213b
9 changed files with 800 additions and 24 deletions


@@ -9,9 +9,15 @@ from vllm.sampling_params import SamplingParams
 from ...utils import check_outputs_equal
 
 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
+MODELS = [
+    "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
+    "pfnet/plamo-2-1b"
+]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
+# Note: Running Plamo2 in the transformers implementation requires installing
+# the causal-conv1d package, which is not listed as a test dependency because
+# it is not compatible with pip-compile.
 
 @pytest.mark.parametrize("model", MODELS)
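
The new note above flags that the Hugging Face reference implementation of PLaMo2 needs the causal-conv1d package, which CI does not install. For reproducing the reference outputs locally, here is a hedged sketch that is not part of this commit: the pip step, prompt, and generation length are assumptions, while the model id comes from the MODELS list and trust_remote_code mirrors the EngineArgs change later in this diff.

```python
# Hedged sketch (not part of this commit): running the Transformers reference
# implementation of PLaMo2 locally. Per the note above, causal-conv1d must be
# installed manually first, e.g. `pip install causal-conv1d`.
# The prompt and max_new_tokens are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("pfnet/plamo-2-1b",
                                             trust_remote_code=True)
inputs = tok("Mamba-attention hybrid models", return_tensors="pt")
out = model.generate(**inputs, do_sample=False, max_new_tokens=32)  # greedy
print(tok.decode(out[0], skip_special_tokens=True))
```
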
@@ -25,21 +31,11 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
 
     # numeric error produces different generation
     if "Bamba" in model:
         example_prompts.pop(3)
 
-    model_kwargs = {
-        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-        # don't use them
-    }
-    if "Zamba2" in model:
-        # Zamba2 HF implementation automatically checks if mamba kernels are
-        # installed
-        model_kwargs = {}
-
-    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
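
With the model_kwargs plumbing gone, the test body reduces to generating greedily with the Transformers runner and with vLLM and comparing the results via check_outputs_equal, imported at the top of the file. A hedged standalone sketch of the vLLM side of that comparison, using the public offline API instead of the vllm_runner fixture; the prompt and max_tokens are illustrative assumptions.

```python
# Hedged sketch of the vLLM side of this greedy-output comparison.
# Greedy decoding (temperature=0) should reproduce the Transformers output
# produced by the sketch above; prompt and max_tokens are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="pfnet/plamo-2-1b", trust_remote_code=True)
outputs = llm.generate(["Mamba-attention hybrid models"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```
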
@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
     # correctly for n > 1 decoding steps inside a
     # chunked prefill forward pass (where we have both prefills
     # and decoding together )
+    if 'plamo-2' in model:
+        dtype = "float"  # use a different dtype for plamo
+
     sampling_params = SamplingParams(n=3,
                                      temperature=1,
                                      seed=0,
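
The comment explains why parallel sampling needs coverage inside a chunked prefill step, and the hunk pins dtype to "float" for PLaMo2 while sampling with n=3, temperature=1, seed=0. A hedged sketch of driving that configuration through the offline API; the prompt, max_tokens, and engine options are assumptions, not the test's exact settings.

```python
# Hedged sketch of the parallel-sampling configuration in the hunk above:
# n=3 sequences per prompt, temperature 1, fixed seed for reproducibility,
# dtype "float" as the added branch sets for PLaMo2. max_tokens and the
# prompt are illustrative assumptions.
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=16)
llm = LLM(model="pfnet/plamo-2-1b", trust_remote_code=True, dtype="float")
for request_output in llm.generate(["Hybrid SSM models"], sampling_params):
    for seq in request_output.outputs:  # three sampled continuations
        print(seq.text)
```
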
@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
     elif "Zamba2" in model:
         example_prompts.pop(7)
         dtype = "half"
+    elif "plamo-2-1b" in model:
+        example_prompts.pop(7)
 
-    model_kwargs = {
-        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-        # don't use them
-    }
-    if "Zamba2" in model:
-        # Zamba2 HF implementation automatically checks if mamba kernels are
-        # installed
-        model_kwargs = {}
-
-    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model,
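
This hunk likewise drops the model_kwargs handling and adds a plamo-2-1b branch that pops a problematic prompt; the vllm_runner call that follows is truncated here, but the test compares its chunked-prefill output against the non-chunked Transformers baseline. A hedged sketch of such a vLLM-side configuration; since the actual arguments are elided from this hunk, every value below is an assumption, not the test's real setting.

```python
# Hedged sketch: run vLLM with chunked prefill enabled and a small
# batched-token budget so long prompts are split across scheduler steps,
# then compare against a non-chunked baseline. All values are illustrative
# assumptions; the test's actual arguments are not shown in this hunk.
from vllm import LLM, SamplingParams

llm = LLM(model="pfnet/plamo-2-1b",
          trust_remote_code=True,
          enable_chunked_prefill=True,
          max_num_batched_tokens=16,  # force prompts to be prefilled in chunks
          max_num_seqs=2)
outputs = llm.generate(["A prompt long enough to span several prefill chunks"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)
```
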
@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    vllm_config = EngineArgs(model=model).create_engine_config()
+    vllm_config = EngineArgs(model=model,
+                             trust_remote_code=True).create_engine_config()
     while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])
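
EngineArgs now passes trust_remote_code=True, presumably because pfnet/plamo-2-1b ships custom modeling code on the Hub and the config cannot be resolved without it; the while loop then keeps appending prompts until the batch size is one that pad_for_cudagraph would actually change, so the padding path is exercised. A hedged sketch of that config plumbing; the model id and the batch sizes queried are illustrative.

```python
# Hedged sketch of the config plumbing in the hunk above: build the engine
# config (trust_remote_code is needed for pfnet/plamo-2-1b) and inspect the
# CUDA-graph-padded size for a few batch sizes. Values are illustrative.
from vllm import EngineArgs

vllm_config = EngineArgs(model="pfnet/plamo-2-1b",
                         trust_remote_code=True).create_engine_config()
for batch_size in (1, 3, 9):
    print(batch_size, "->", vllm_config.pad_for_cudagraph(batch_size))
```
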