[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-05-14 00:13:56 -05:00
committed by GitHub
parent 2d912fb66f
commit 7b2f28deba
6 changed files with 14 additions and 9 deletions

View File

@@ -27,8 +27,8 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
@@ -36,7 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)