[Feature][ROCm] Enable fusion pass for torch.compile on ROCm (#15050)

Signed-off-by: charlifu <charlifu@amd.com>
Authored by Charlie Fu on 2025-03-31 06:42:18 -05:00; committed by GitHub.
parent effc5d24fa
commit e85829450d
8 changed files with 92 additions and 72 deletions

View File

@@ -2,7 +2,6 @@
import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs
import vllm.plugins
@@ -14,9 +13,12 @@ from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
class TestModel(torch.nn.Module):
@@ -59,8 +61,8 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled):
torch.set_default_device("cuda")