[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -5,9 +5,10 @@ import torch
 import vllm._custom_ops as ops
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
-QUANT_DTYPES = [torch.float8_e4m3fn]
+QUANT_DTYPES = [current_platform.fp8_dtype()]
 NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
 HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
 SEEDS = [0]
@@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
 def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     out_shape = (x.shape[0], x.shape[1] // 2)
     out = torch.empty(out_shape,
-                      dtype=torch.torch.float8_e4m3fn,
+                      dtype=current_platform.fp8_dtype(),
                       device=x.device)
     torch.ops._C.silu_and_mul_quant(out, x, scale)
     return out
Reference in New Issue
Block a user