[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
|
||||
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& scale) {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
|
||||
out.dtype() == torch::kFloat8_e4m3fnuz);
|
||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||
input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||
|
||||
Reference in New Issue
Block a user