[NVIDIA] Support Cutlass w8a8 FP8 for Blackwell Geforce GPUs (sm120) (#17280)

Signed-off-by: kaln27 <liaojuncheng123@foxmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Joonchen Liau
2025-07-02 20:47:19 +08:00
committed by GitHub
parent 0c600b9ab6
commit 9e5552aa13
7 changed files with 238 additions and 1 deletions

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm