[Model] Add PLaMo2 (#14323)
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
@@ -9,9 +9,15 @@ from vllm.sampling_params import SamplingParams
 from ...utils import check_outputs_equal
 
 # This test is for the hybrid models
-MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
+MODELS = [
+    "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
+    "pfnet/plamo-2-1b"
+]
 # Bamba at Fp32 is too big for the CI (L4 GPU).
 # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
+# Note: Running Plamo2 with the transformers implementation requires installing
+# the causal-conv1d package, which is not listed as a test dependency because
+# it is not compatible with pip-compile.
 
 
 @pytest.mark.parametrize("model", MODELS)
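The new "pfnet/plamo-2-1b" entry is picked up by every test parametrized over MODELS. Outside the test harness, the model can be exercised directly through vLLM's offline API. The sketch below is not part of the commit; the prompt and max_tokens value are illustrative, and the trust_remote_code flag mirrors the engine-config change in the last hunk of this diff.

from vllm import LLM, SamplingParams

# PLaMo2's HF implementation is loaded as remote code from the Hub repository,
# so remote code must be trusted (an assumption consistent with this diff).
llm = LLM(model="pfnet/plamo-2-1b", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of Japan is"], params)
print(outputs[0].outputs[0].text)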
@@ -25,21 +31,11 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-
     # numeric error produces different generation
     if "Bamba" in model:
         example_prompts.pop(3)
 
-    model_kwargs = {
-        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-                                     # don't use them
-    }
-    if "Zamba2" in model:
-        # Zamba2 HF implementation automatically checks if mamba kernels are
-        # installed
-        model_kwargs = {}
-
-    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
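The simplified hf_runner call produces the HF reference generations, and the vllm_runner block that follows produces vLLM's. The comparison both runs feed into is the check_outputs_equal helper imported at the top of the file; a minimal sketch of the typical call, assuming the keyword names used by vLLM's shared test utilities (this hunk itself does not show them):

# Sketch only: compare HF and vLLM greedy generations prompt by prompt.
check_outputs_equal(
    outputs_0_lst=hf_outputs,
    outputs_1_lst=vllm_outputs,
    name_0="hf",
    name_1="vllm",
)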
@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
     # correctly for n > 1 decoding steps inside a
     # chunked prefill forward pass (where we have both prefills
     # and decoding together)
+
+    if 'plamo-2' in model:
+        dtype = "float"  # use a different dtype for plamo
+
     sampling_params = SamplingParams(n=3,
                                      temperature=1,
                                      seed=0,
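The added branch forces float32 for PLaMo2 in this test only; the SamplingParams at the end of the hunk is what exercises n > 1 decoding inside a chunked prefill pass. A self-contained sketch of the same sampling setup, with the max_tokens value and the chunked-prefill flag as assumptions rather than values taken from this diff:

from vllm import LLM, SamplingParams

# Sketch: three seeded samples per prompt while chunked prefill is enabled.
llm = LLM(model="pfnet/plamo-2-1b",
          trust_remote_code=True,
          enable_chunked_prefill=True)
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=16)
for request_output in llm.generate(["Hello, my name is"], sampling_params):
    # Each RequestOutput carries n CompletionOutput entries.
    print([completion.text for completion in request_output.outputs])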
@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
         example_prompts.pop(3)
         example_prompts.pop(2)
         dtype = "half"  # use a different dtype for Bamba
 
     elif "Zamba2" in model:
         example_prompts.pop(7)
         dtype = "half"
+    elif "plamo-2-1b" in model:
+        example_prompts.pop(7)
 
-    model_kwargs = {
-        "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-                                     # don't use them
-    }
-    if "Zamba2" in model:
-        # Zamba2 HF implementation automatically checks if mamba kernels are
-        # installed
-        model_kwargs = {}
-
-    with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model,
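This hunk mirrors the simplification in test_models (no use_mamba_kernels override for the HF reference run) and adds PLaMo2 to the per-model prompt adjustments. The vllm_runner call it ends on is where chunked prefill gets enabled; a hedged sketch of how the chunked run is then checked against the non-chunked HF baseline, with flag values as assumptions rather than values taken from this diff:

# Sketch: generate with chunked prefill and compare to the HF baseline above.
with vllm_runner(model,
                 dtype=dtype,
                 enable_chunked_prefill=True,
                 max_num_batched_tokens=5,
                 max_num_seqs=2) as vllm_model:
    chunked = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(outputs_0_lst=chunked,
                    outputs_1_lst=non_chunked,
                    name_0="chunked",
                    name_1="non_chunked")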
@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    vllm_config = EngineArgs(model=model).create_engine_config()
+    vllm_config = EngineArgs(model=model,
+                             trust_remote_code=True).create_engine_config()
     while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])
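The trust_remote_code=True addition is what allows the engine config to be built for pfnet/plamo-2-1b, whose HF implementation is presumably provided as remote code on the Hub. A short sketch of the padding check this test relies on; the model id and batch size are illustrative, while pad_for_cudagraph is used exactly as in the hunk above:

from vllm import EngineArgs

# Build the engine config the same way the test does.
vllm_config = EngineArgs(model="pfnet/plamo-2-1b",
                         trust_remote_code=True).create_engine_config()
# Round a batch size up to the nearest CUDA-graph-captured size.
padded_batch_size = vllm_config.pad_for_cudagraph(3)
print(padded_batch_size)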