[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-05-14 00:13:56 -05:00
committed by GitHub
parent 2d912fb66f
commit 7b2f28deba
6 changed files with 14 additions and 9 deletions

View File

@@ -5,9 +5,10 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
# Parameter grids for the silu_and_mul + FP8-quant kernel tests.
DTYPES = [torch.bfloat16, torch.float16]
# Platform-specific FP8 dtype so the test runs on both CUDA and ROCm
# (presumably e4m3fn vs. e4m3fnuz — determined by current_platform).
QUANT_DTYPES = [current_platform.fp8_dtype()]
NUM_TOKENS = [1, 17, 86, 1234, 3045]  # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096]  # Arbitrary values for testing
SEEDS = [0]
@@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Run the fused silu-and-mul + FP8 quantization custom op.

    Args:
        x: Input activations; dim 1 holds the concatenated gate/up halves,
           so the output hidden size is ``x.shape[1] // 2``.
        scale: Quantization scale passed through to the custom kernel.

    Returns:
        A new tensor of shape ``(x.shape[0], x.shape[1] // 2)`` in the
        platform-specific FP8 dtype, filled in-place by the custom op.
    """
    out_shape = (x.shape[0], x.shape[1] // 2)
    # Use the platform's FP8 dtype (not a hard-coded float8_e4m3fn) so the
    # same test covers CUDA and ROCm builds.
    out = torch.empty(out_shape,
                      dtype=current_platform.fp8_dtype(),
                      device=x.device)
    torch.ops._C.silu_and_mul_quant(out, x, scale)
    return out