[Kernel] SM90 CUTLASS FP8 GEMM: add support for swap AB + kernel tuning (#20396)

Signed-off-by: Faqin Zhong <faqin.zhong@gmail.com>
Co-authored-by: Duncan Moss <djm.moss@gmail.com>
This commit is contained in:
lyrisz
2025-07-28 16:13:58 -07:00
committed by GitHub
parent 8aa1485fcf
commit c6c9122d50
3 changed files with 276 additions and 51 deletions

View File

@@ -1,6 +1,5 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
@@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}