[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Thomas Parnell
2025-08-15 14:57:06 +02:00
committed by GitHub
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

View File

@@ -431,3 +431,65 @@ def test_full_cuda_graph(
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
with hf_runner(model) as hf_model:
if model not in HF_UNSUPPORTED_MODELS:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
hf_outputs = None
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32",
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if hf_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
name_0="hf",
name_1="vllm-v0",
)
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)