[v1] Support mamba2 (#19327)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -17,9 +17,10 @@ SSM_MODELS = [
     "state-spaces/mamba-130m-hf",
     "tiiuae/falcon-mamba-tiny-dev",
     # TODO: Compare to a Mamba2 model. The HF transformers implementation of
-    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
+    # doesn't compare vLLM output with HF output.
     # See https://github.com/huggingface/transformers/pull/35943
-    # "mistralai/Mamba-Codestral-7B-v0.1",
+    "mistralai/Mamba-Codestral-7B-v0.1",
 ]
 
 HYBRID_MODELS = [
@@ -35,6 +36,10 @@ HYBRID_MODELS = [
     "hmellor/tiny-random-BambaForCausalLM",
 ]
 
+V1_SUPPORTED_MODELS = [
+    "mistralai/Mamba-Codestral-7B-v0.1",
+]
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -46,24 +51,50 @@ def test_models(
     hf_runner,
     vllm_runner,
     example_prompts,
+    monkeypatch,
     model: str,
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
     with hf_runner(model) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
+        if model != "mistralai/Mamba-Codestral-7B-v0.1":
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
 
     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+    if model in V1_SUPPORTED_MODELS:
+        with monkeypatch.context() as m:
+            m.setenv("VLLM_USE_V1", "1")
+            with vllm_runner(model,
+                             max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=True,
+                             enable_prefix_caching=False) as vllm_model:
+                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+    else:
+        vllm_v1_outputs = None
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    if model in V1_SUPPORTED_MODELS:
+        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+        check_logprobs_close(
+            outputs_0_lst=ref_outputs,
+            outputs_1_lst=vllm_v1_outputs,
+            name_0="hf" if hf_outputs is not None else "vllm-v0",
+            name_1="vllm-v1",
+        )
 
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
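For context, a minimal sketch (not part of this commit) of the configuration the new V1 branch of the test exercises: running Mamba-Codestral through vLLM's offline LLM API with the V1 engine selected via VLLM_USE_V1, eager mode, prefix caching disabled, and a small max_num_seqs. The prompt and max_tokens values below are placeholders, and the exact flag set assumes a vLLM build from around this commit.

    # Hedged sketch mirroring the V1 settings used in the test above.
    import os

    # Select the V1 engine, as the test does via monkeypatch.
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="mistralai/Mamba-Codestral-7B-v0.1",
        max_num_seqs=4,               # MAX_NUM_SEQS in the test, to avoid OOM
        enforce_eager=True,           # no CUDA graph capture, as in the test
        enable_prefix_caching=False,  # prefix caching disabled for the Mamba2 path
    )

    outputs = llm.generate(
        ["The capital of France is"],                    # placeholder prompt
        SamplingParams(temperature=0.0, max_tokens=32),  # greedy decoding, like the test
    )
    print(outputs[0].outputs[0].text)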