diff --git a/tests/models/quantization/test_gpt_oss.py b/tests/models/quantization/test_gpt_oss.py index 7599a5a5e..21cc9555b 100644 --- a/tests/models/quantization/test_gpt_oss.py +++ b/tests/models/quantization/test_gpt_oss.py @@ -21,6 +21,7 @@ import lm_eval import pytest from packaging import version +from vllm.platforms.rocm import on_gfx950 from vllm.utils.torch_utils import cuda_device_count_stateless MODEL_ACCURACIES = { @@ -83,11 +84,17 @@ class EvaluationConfig: @pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) @pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items()) def test_gpt_oss_attention_quantization( - model_name: str, tp_size: int, expected_accuracy: float + model_name: str, + tp_size: int, + expected_accuracy: float, + monkeypatch: pytest.MonkeyPatch, ): if tp_size > cuda_device_count_stateless(): pytest.skip("Not enough GPUs to run this test case") + if "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8" in model_name and on_gfx950(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + model_args = EvaluationConfig(model_name).get_model_args(tp_size) extra_run_kwargs = {