[CI/Build][AMD] Skip tests in test_fusions_e2e and test_dbo_dp_ep_gsm8k that require non-existing imports for ROCm (#30417)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith
2025-12-11 18:24:20 -06:00
committed by GitHub
parent d527cf0b3d
commit 48661d275f
2 changed files with 27 additions and 1 deletions

View File

@@ -138,6 +138,17 @@ elif current_platform.is_rocm():
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def has_cuda_graph_wrapper_metadata() -> bool:
from importlib import import_module
try:
module = import_module("torch._inductor.utils")
module.CUDAGraphWrapperMetadata # noqa B018
except AttributeError:
return False
return True
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
@@ -145,7 +156,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
# quant_fp4 only has the custom impl
+ list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.parametrize(
"inductor_graph_partition",
[
pytest.param(
True,
marks=pytest.mark.skipif(
not has_cuda_graph_wrapper_metadata(),
reason="This test requires"
"torch._inductor.utils.CUDAGraphWrapperMetadata to run",
),
),
False,
],
)
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],