# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest
import torch

from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    supports_trtllm_attention,
    use_trtllm_attention,
)
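

# Overview (inferred from the tests below): supports_trtllm_attention gates on
# platform-level support, can_use_trtllm_attention additionally checks that the
# query/KV head counts are compatible, and use_trtllm_attention applies the
# per-call heuristics (token count, prefill vs. decode, dtypes, sinks, spec
# decode) exercised by the cases in this module.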


MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
    "Qwen2.5-0.5B": dict(num_qo_heads=14, num_kv_heads=2),
    "Mistral-7B": dict(num_qo_heads=32, num_kv_heads=8),
    "Gemma-2-9B": dict(num_qo_heads=8, num_kv_heads=4),
    "Falcon-40B": dict(num_qo_heads=128, num_kv_heads=8),
}
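# Note: every config above uses num_qo_heads evenly divisible by num_kv_heads,
# which appears to be the layout can_use_trtllm_attention requires
# (test_can_use_incompatible_heads covers the 40/6 counterexample).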


def get_config(model: str) -> dict:
    """Return the attention config for a model."""
    return MODEL_CONFIGS[model]


DEFAULT_KWARGS = dict(
    **get_config("Llama-3-70B"),
    num_tokens=128,
    max_seq_len=4096,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,
    is_prefill=False,
    force_use_trtllm=None,
    has_sinks=False,
    has_spec=False,
)


def _call(**overrides) -> bool:
    kwargs = {**DEFAULT_KWARGS, **overrides}
    return use_trtllm_attention(**kwargs)
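

# Illustrative usage: _call(num_tokens=256) evaluates use_trtllm_attention with
# the Llama-3-70B decode-phase baseline above, overriding only num_tokens.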


@pytest.fixture(autouse=True)
def _clear_supports_cache():
    """Clear functools.cache to ensure each test runs independently."""
    supports_trtllm_attention.cache_clear()


# supports_trtllm_attention


@patch("vllm.envs.VLLM_BATCH_INVARIANT", True)
def test_supports_batch_invariant_disables():
    assert supports_trtllm_attention() is False


@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability",
    return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
def test_supports_sm100_with_artifactory(_art, _cap):
    assert supports_trtllm_attention() is True


@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability",
    return_value=False,
)
def test_supports_non_sm100_platform(_cap):
    assert supports_trtllm_attention() is False


@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability",
    return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap):
    assert supports_trtllm_attention() is False


# can_use_trtllm_attention


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=False)
def test_can_use_force_disabled(_mock):
    cfg = get_config("Llama-3-70B")
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_compatible_heads(_sup, _force):
    cfg = get_config("Llama-3-70B")
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is True
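

# 40 query heads over 6 KV heads is not an even grouping (40 % 6 != 0), which
# appears to be why the head layout below is rejected even though the platform
# check is mocked to succeed.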
@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_can_use_incompatible_heads(_sup, _force):
|
|
assert can_use_trtllm_attention(40, 6) is False
|
|
|
|
|
|
@pytest.mark.parametrize("model", list(MODEL_CONFIGS.keys()))
|
|
@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
|
|
def test_can_use_platform_unsupported(_sup, _force, model):
|
|
cfg = get_config(model)
|
|
assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False
|
|
|
|
|
|


# use_trtllm_attention


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_off(_mock):
    assert _call(force_use_trtllm=False) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_dcp_fallback(_mock):
    assert _call(dcp_world_size=2) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported(_mock):
    assert _call() is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported_force_on_still_false(_mock):
    assert _call(force_use_trtllm=True) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads(_mock):
    assert _call(num_qo_heads=40, num_kv_heads=6) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads_force_on_still_false(_mock):
    assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False
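

# With platform support mocked on, speculative decoding (has_spec) and attention
# sinks (has_sinks) each appear to be enough on their own to select the TRT-LLM
# path, per the has_spec case below and the has_sinks case further down.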
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_spec_decode_enables(_mock):
|
|
assert _call(has_spec=True, is_prefill=False) is True
|
|
|
|
|
|
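

# An FP8 query dtype appears to force the TRT-LLM path; current_platform.fp8_dtype
# is mocked here so that the query dtype matches the platform's FP8 type.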
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
@patch(
|
|
"vllm.utils.flashinfer.current_platform.fp8_dtype",
|
|
return_value=torch.float8_e4m3fn,
|
|
)
|
|
def test_use_fp8_query_forces_trtllm(_fp8, _sup):
|
|
assert _call(q_dtype=torch.float8_e4m3fn) is True
|
|
|
|
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_sinks_force_trtllm(_mock):
|
|
assert _call(has_sinks=True) is True
|
|
|
|
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_auto_prefill_kv_auto(_mock):
|
|
assert _call(is_prefill=True, kv_cache_dtype="auto") is True
|
|
|
|
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_auto_prefill_kv_fp8(_mock):
|
|
assert _call(is_prefill=True, kv_cache_dtype="fp8") is False
|
|
|
|
|
|
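

# In "auto" mode the decode heuristic evidently enables TRT-LLM attention for a
# small token count (128) but not for a large one (512); the exact threshold is
# internal to use_trtllm_attention.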
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_auto_decode_small_batch(_mock):
|
|
assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True
|
|
|
|
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_auto_decode_large_batch(_mock):
|
|
assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False
|
|
|
|
|
|
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
|
|
def test_use_force_on(_mock):
|
|
assert _call(force_use_trtllm=True) is True
|