diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py
index c6f245669..961d6873f 100644
--- a/tests/compile/test_fusion_attn.py
+++ b/tests/compile/test_fusion_attn.py
@@ -18,7 +18,7 @@ from tests.compile.fusion_test_utils import (
     is_blackwell,
     run_model,
 )
-from tests.utils import cuda_device_count_stateless, flat_product
+from tests.utils import flat_product
 from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention.layer import Attention
@@ -265,13 +265,13 @@ if current_platform.is_cuda():
     HEADS = [(64, 8), (40, 8)]
     PATTERN_TEST_MODELS_FP8 = [
         (
-            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+            "RedHatAI/Meta-Llama-3.1-8B-FP8",
             TestAttentionFp8StaticQuantPatternModel,
         )
     ]
     PATTERN_TEST_MODELS_FP4 = [
         (
-            "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+            "nvidia/Llama-3.1-8B-Instruct-NVFP4",
             TestAttentionNvfp4QuantPatternModel,
         )
     ]
@@ -331,9 +331,8 @@ def test_attention_quant_pattern(
     if backend == AttentionBackendEnum.FLASHINFER and (
         not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
     ):
+        # This also captures the FP4 case
         pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
-    if "Llama-4-Scout" in model_name and cuda_device_count_stateless() < 2:
-        pytest.skip("Llama-4-Scout requires at least 2 GPUs")
 
     custom_ops_list = custom_ops.split(",") if custom_ops else []