Enable V1 for Hybrid SSM/Attention Models (#20016)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -19,31 +20,55 @@ pytestmark = pytest.mark.hybrid_model
|
||||
SSM_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"tiiuae/falcon-mamba-tiny-dev",
|
||||
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
|
||||
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
|
||||
# doesn't compare vLLM output with HF output.
|
||||
# See https://github.com/huggingface/transformers/pull/35943
|
||||
"mistralai/Mamba-Codestral-7B-v0.1",
|
||||
]
|
||||
|
||||
HYBRID_MODELS = [
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
# NOTE: Currently the test failes due to HF transformers issue fixed in:
|
||||
# https://github.com/huggingface/transformers/pull/39033
|
||||
# We will enable vLLM test for Granite after next HF transformers release.
|
||||
# "ibm-granite/granite-4.0-tiny-preview",
|
||||
# NOTE: Running Plamo2 in transformers implementation requires to install
|
||||
# causal-conv1d package, which is not listed as a test dependency as it's
|
||||
# not compatible with pip-compile.
|
||||
"pfnet/plamo-2-1b",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
"hmellor/tiny-random-BambaForCausalLM",
|
||||
"ibm-ai-platform/Bamba-9B-v1",
|
||||
"nvidia/Nemotron-H-8B-Base-8K",
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
]
|
||||
|
||||
HF_UNSUPPORTED_MODELS = [
|
||||
# The HF transformers implementation of
|
||||
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
|
||||
# doesn't compare vLLM output with HF output.
|
||||
# See https://github.com/huggingface/transformers/pull/35943
|
||||
"mistralai/Mamba-Codestral-7B-v0.1",
|
||||
# Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
|
||||
# for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
|
||||
"nvidia/Nemotron-H-8B-Base-8K",
|
||||
# NOTE: Currently the test fails due to HF transformers issue fixed in:
|
||||
# https://github.com/huggingface/transformers/pull/39033
|
||||
# We will enable vLLM test for Granite after next HF transformers release.
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
]
|
||||
|
||||
V1_SUPPORTED_MODELS = [
|
||||
"mistralai/Mamba-Codestral-7B-v0.1",
|
||||
"ibm-ai-platform/Bamba-9B-v1",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
"nvidia/Nemotron-H-8B-Base-8K",
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
]
|
||||
|
||||
ATTN_BLOCK_SIZES = {
|
||||
"ibm-ai-platform/Bamba-9B-v1": 528,
|
||||
"Zyphra/Zamba2-1.2B-instruct": 80,
|
||||
"nvidia/Nemotron-H-8B-Base-8K": 528,
|
||||
"ibm-granite/granite-4.0-tiny-preview": 400,
|
||||
"tiiuae/Falcon-H1-0.5B-Base": 800,
|
||||
}
|
||||
|
||||
# Avoid OOM
|
||||
MAX_NUM_SEQS = 4
|
||||
|
||||
@@ -60,8 +85,16 @@ def test_models(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
if model != "mistralai/Mamba-Codestral-7B-v0.1":
|
||||
if model not in HF_UNSUPPORTED_MODELS:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
@@ -72,12 +105,21 @@ def test_models(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
|
||||
block_size = ATTN_BLOCK_SIZES[model]
|
||||
else:
|
||||
block_size = 16
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
if model in HYBRID_MODELS:
|
||||
# required due to reorder_batch behaviour
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=False) as vllm_model:
|
||||
enable_prefix_caching=False,
|
||||
block_size=block_size) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
@@ -111,6 +153,14 @@ def test_batching(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
for_loop_outputs = []
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
for prompt in example_prompts:
|
||||
|
||||
@@ -169,7 +169,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
|
||||
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
|
||||
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
|
||||
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
|
||||
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
|
||||
min_transformers_version="4.53"),
|
||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
|
||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||
|
||||
@@ -13,7 +13,6 @@ UNSUPPORTED_MODELS_V1 = [
|
||||
"openai/whisper-large-v3", # transcription
|
||||
"facebook/bart-large-cnn", # encoder decoder
|
||||
"state-spaces/mamba-130m-hf", # mamba1
|
||||
"hmellor/tiny-random-BambaForCausalLM", # hybrid
|
||||
"BAAI/bge-m3", # embedding
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user