[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-05-14 00:13:56 -05:00
committed by GitHub
parent 2d912fb66f
commit 7b2f28deba
6 changed files with 14 additions and 9 deletions

View File

@@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
out.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);