[Kernel] Integrate SM100 MXFP8 blockscaled grouped MM and quant kernels (#34448)
Signed-off-by: EdalatiAli <aliedalati@cohere.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -771,6 +771,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
# The arch spelling differs by toolkit generation: CUDA 13+ accepts the
# family-conditional "f" suffix, older toolkits need the per-arch "a" forms.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
  cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
  cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
# Build only when the compiler is new enough (>= 12.8) AND at least one
# compatible arch intersects the requested CUDA_ARCHS.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
  set(SRCS
      "csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
      "csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
    CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
  list(APPEND VLLM_EXT_SRC "${SRCS}")
  # Macro guarding the host-side entry points in the extension sources.
  list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
  message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
else()
  # Report which of the two preconditions failed.
  if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8
      AND ES_MXFP8_GROUPED_MM_ARCHS)
    message(STATUS "Not building ES MXFP8 grouped kernels as CUDA Compiler version is "
                   "not >= 12.8.")
  else()
    message(STATUS "Not building ES MXFP8 grouped kernels as no compatible archs found "
                   "in CUDA target architectures.")
  endif()
endif()
|
||||
|
||||
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
|
||||
|
||||
60
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
Normal file
60
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
|
||||
|
||||
// SM100 MXFP8 block-scaled grouped GEMM entry point.
//
// For each expert e, computes d_e = a_e @ b_e, where a/d rows are packed
// contiguously by expert (starting rows given by expert_offsets) and the
// fp8 operands carry block scale factors (sfa/sfb).
//
//   a                  : (num_tokens, k) fp8 activations, row major
//   b                  : (num_experts, k, n) fp8 weights, column major
//   sfa / sfb          : block scale factors for a / b
//   d                  : output; fp16 or bf16 (selects the instantiation)
//   problem_sizes      : (num_experts, 3) int32, (m, n, k) per expert
//   expert_offsets     : int32 starting row of each expert in a/d
//   blockscale_offsets : int32 starting row of each expert in sfa
//
// k and n must be multiples of 128 (checked below). Runs on the current
// CUDA stream. Fails with a TORCH_CHECK when CUTLASS SM100 support was not
// compiled in.
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
                              const torch::Tensor& sfa,
                              const torch::Tensor& sfb, torch::Tensor& d,
                              const torch::Tensor& problem_sizes,
                              const torch::Tensor& expert_offsets,
                              const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  // Shape / dtype preconditions for the grouped launch.
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3,
              "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
              "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32");
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
              "expert_offsets must be int32");
  TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
              "blockscale_offsets must be int32");
  TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
  TORCH_CHECK(b.dim() == 3,
              "b must be a 3D tensor of shape (num_experts, k, n)");
  TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
              "k should align 128");
  TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
  TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
  // Column major per expert matrix: unit stride along the k dimension.
  TORCH_CHECK(b.strides()[1] == 1, "b must be column major");

  auto stream = at::cuda::getCurrentCUDAStream();
  // Dispatch on the output dtype; the fp8 input type is fixed by the traits.
  if (d.dtype() == torch::kBFloat16) {
    expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
        cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
                             blockscale_offsets, stream);
  } else if (d.dtype() == torch::kFloat16) {
    expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
        cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
                         blockscale_offsets, stream);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false,
              "No implemented cutlass_mxfp8_grouped_mm for "
              "current device");
#endif
}
|
||||
|
||||
#include "core/registration.h"

// Register the CUDA implementation of cutlass_mxfp8_grouped_mm so the torch
// dispatcher routes CUDA-tensor calls here.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
}
|
||||
141
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
Normal file
141
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
Normal file
@@ -0,0 +1,141 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh
|
||||
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Computes per-expert base pointers for the grouped GEMM operands.
//
// A, D and SFA are packed contiguously across experts (starting rows given
// by expert_offsets / blockscale_offsets); B and SFB are dense per-expert
// buffers indexed directly by expert id. operator() writes one device
// pointer per expert into the *_offsets arrays, which the CUTLASS grouped
// kernel consumes as its ptr-array arguments.
template <typename GemmTraits>
struct CutlassMxfp8GroupedMmOffsetFunctor {
  using Gemm = typename GemmTraits::Gemm;
  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementSF = typename GemmTraits::ElementSF;
  using ElementD = typename GemmTraits::ElementOutput;
  // Input: per-expert starting-row tables (int32, device memory).
  int* expert_offsets{nullptr};
  int* blockscale_offsets{nullptr};
  // Output: operand base pointers (read) and the per-expert pointer arrays
  // filled by operator() (written).
  ElementA* a_base{nullptr};
  ElementB* b_base{nullptr};
  ElementSF* sfa_base{nullptr};
  ElementSF* sfb_base{nullptr};
  ElementD* d_base{nullptr};
  ElementA** a_offsets{nullptr};
  ElementB** b_offsets{nullptr};
  ElementSF** sfa_offsets{nullptr};
  ElementSF** sfb_offsets{nullptr};
  ElementD** d_offsets{nullptr};

  CutlassMxfp8GroupedMmOffsetFunctor() = default;
  CutlassMxfp8GroupedMmOffsetFunctor(
      int* _expert_offsets, int* _blockscale_offsets, ElementA* _a_base,
      ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base,
      ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets,
      ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets)
      : expert_offsets{_expert_offsets},
        blockscale_offsets{_blockscale_offsets},
        a_base(_a_base),
        b_base(_b_base),
        sfa_base(_sfa_base),
        sfb_base(_sfb_base),
        d_base(_d_base),
        a_offsets(_a_offsets),
        b_offsets(_b_offsets),
        sfa_offsets(_sfa_offsets),
        sfb_offsets(_sfb_offsets),
        d_offsets(_d_offsets) {}

  // Fills the pointer slots for `expert_id` given its (m, n, k) problem.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
    int64_t blockscale_offset =
        static_cast<int64_t>(blockscale_offsets[expert_id]);
    // Packed operands advance by whole rows; dense ones by whole matrices.
    int64_t a_stride = expert_offset * k;
    int64_t b_stride = expert_id * k * n;
    int64_t d_stride = expert_offset * n;
    // One scale factor per 32-element group along k => k / 32 per row.
    int64_t sfa_stride = blockscale_offset * (k / 32);
    int64_t sfb_stride = expert_id * n * (k / 32);

    a_offsets[expert_id] = a_base + a_stride;
    b_offsets[expert_id] = b_base + b_stride;
    sfa_offsets[expert_id] = sfa_base + sfa_stride;
    sfb_offsets[expert_id] = sfb_base + sfb_stride;
    d_offsets[expert_id] = d_base + d_stride;
  }
};
|
||||
|
||||
// Computes the CUTLASS scale-factor layouts (SFA / SFB) for each expert from
// its (m, n, k) problem shape and stores them into per-expert arrays.
template <typename GemmTraits>
struct CutlassMxfp8GroupedMmLayoutFunctor {
  using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig;
  using LayoutSFA = typename GemmTraits::LayoutSFA;
  using LayoutSFB = typename GemmTraits::LayoutSFB;
  // Output arrays, one layout entry per expert.
  LayoutSFA* layout_sfa_base{nullptr};
  LayoutSFB* layout_sfb_base{nullptr};

  CutlassMxfp8GroupedMmLayoutFunctor() = default;
  CutlassMxfp8GroupedMmLayoutFunctor(LayoutSFA* _layout_sfa_base,
                                     LayoutSFB* _layout_sfb_base)
      : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}

  // Fills both layouts for `expert_id`; the batch dimension is always 1.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id;
    LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id;
    *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
        cute::make_shape(m, n, k, 1));
    *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
        cute::make_shape(m, n, k, 1));
  }
};
|
||||
|
||||
// Computes packed CUTLASS strides for each expert's A (m x k), B (n x k)
// and D (m x n) sub-problems and stores them into per-expert arrays.
template <typename GemmTraits>
struct CutlassMxfp8GroupedMmStrideFunctor {
  using StrideA = typename GemmTraits::StrideA;
  using StrideB = typename GemmTraits::StrideB;
  using StrideD = typename GemmTraits::StrideD;
  // Output arrays, one stride entry per expert.
  StrideA* stride_A_base{nullptr};
  StrideB* stride_B_base{nullptr};
  StrideD* stride_D_base{nullptr};

  CutlassMxfp8GroupedMmStrideFunctor() = default;
  CutlassMxfp8GroupedMmStrideFunctor(StrideA* _stride_A_base,
                                     StrideB* _stride_B_base,
                                     StrideD* _stride_D_base)
      : stride_A_base(_stride_A_base),
        stride_B_base(_stride_B_base),
        stride_D_base(_stride_D_base) {}

  // Fills the strides for `expert_id`; the batch dimension is always 1.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    StrideA* stride_A = stride_A_base + expert_id;
    StrideB* stride_B = stride_B_base + expert_id;
    StrideD* stride_D = stride_D_base + expert_id;
    *stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
    *stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
    *stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
  }
};
|
||||
|
||||
// Single-block precompute kernel: thread i handles expert i. Reads the
// expert's (m, n, k) triple from `problem_sizes` (row major, three int32
// per expert) and runs the three functors to fill the pointer, layout and
// stride arrays consumed by the grouped GEMM.
//
// NOTE(review): there is no bounds check on threadIdx.x — callers must
// launch exactly one thread per expert (and at most 1024 experts).
template <typename OffsetFunctor, typename LayoutFunctor,
          typename StrideFunctor>
__global__ void cutlassMxfp8GroupedMmPreComputeKernel(
    int* problem_sizes, OffsetFunctor offset_functor,
    LayoutFunctor layout_functor, StrideFunctor stride_functor) {
  int64_t expert_id = static_cast<int64_t>(threadIdx.x);
  int m = problem_sizes[expert_id * 3 + 0];
  int n = problem_sizes[expert_id * 3 + 1];
  int k = problem_sizes[expert_id * 3 + 2];

  offset_functor(expert_id, m, n, k);
  layout_functor(expert_id, m, n, k);
  stride_functor(expert_id, m, n, k);
}
|
||||
|
||||
} // namespace expert_specialization
|
||||
179
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
Normal file
179
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
Normal file
@@ -0,0 +1,179 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
|
||||
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
|
||||
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
// Fills the per-expert device arrays (operand pointers, scale-factor
// layouts, operand strides) consumed by the grouped GEMM. All output
// tensors are preallocated by the caller; this only builds the functors
// and launches the single-block precompute kernel on `stream`.
//
// NOTE(review): the kernel launch is not followed by cudaGetLastError();
// a bad launch would only surface at a later synchronizing call.
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm_pre_compute(
    torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
    torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
    torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
    torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
  using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
  using ElementA = typename OffsetFunctor::ElementA;
  using ElementB = typename OffsetFunctor::ElementB;
  using ElementSF = typename OffsetFunctor::ElementSF;
  using ElementD = typename OffsetFunctor::ElementD;

  using LayoutFunctor = CutlassMxfp8GroupedMmLayoutFunctor<GemmTraits>;
  using LayoutSFA = typename LayoutFunctor::LayoutSFA;
  using LayoutSFB = typename LayoutFunctor::LayoutSFB;

  using StrideFunctor = CutlassMxfp8GroupedMmStrideFunctor<GemmTraits>;
  using StrideA = typename StrideFunctor::StrideA;
  using StrideB = typename StrideFunctor::StrideB;
  using StrideD = typename StrideFunctor::StrideD;

  // The precompute kernel maps one thread to one expert within a single
  // block, so the expert count is capped by the maximum block size.
  int num_experts = (int)expert_offsets.size(0);
  TORCH_CHECK(num_experts <= 1024,
              "Number of experts cannot exceed 1024, the maximum number of "
              "threads per block.");

  OffsetFunctor offset_functor(
      reinterpret_cast<int*>(expert_offsets.data_ptr()),
      reinterpret_cast<int*>(blockscale_offsets.data_ptr()),
      reinterpret_cast<ElementA*>(a.data_ptr()),
      reinterpret_cast<ElementB*>(b.data_ptr()),
      reinterpret_cast<ElementSF*>(sfa.data_ptr()),
      reinterpret_cast<ElementSF*>(sfb.data_ptr()),
      reinterpret_cast<ElementD*>(d.data_ptr()),
      reinterpret_cast<ElementA**>(a_ptrs.data_ptr()),
      reinterpret_cast<ElementB**>(b_ptrs.data_ptr()),
      reinterpret_cast<ElementSF**>(sfa_ptrs.data_ptr()),
      reinterpret_cast<ElementSF**>(sfb_ptrs.data_ptr()),
      reinterpret_cast<ElementD**>(d_ptrs.data_ptr()));
  LayoutFunctor layout_functor(
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
  StrideFunctor stride_functor(reinterpret_cast<StrideA*>(stride_a.data_ptr()),
                               reinterpret_cast<StrideB*>(stride_b.data_ptr()),
                               reinterpret_cast<StrideD*>(stride_d.data_ptr()));
  // One thread per expert; functors are passed to the kernel by value.
  cutlassMxfp8GroupedMmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), offset_functor,
      layout_functor, stride_functor);
}
|
||||
|
||||
// Configures and runs the CUTLASS grouped GEMM over the per-expert arrays
// prepared by cutlass_mxfp8_grouped_mm_pre_compute.
//
// The workspace is allocated as a torch uint8 tensor on the output device,
// so its lifetime is managed by torch's allocator. The kernel is launched
// on `stream` with PDL (programmatic dependent launch) enabled.
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm(
    const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
    const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
    const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
    const torch::Tensor& stride_b, const torch::Tensor& stride_d,
    const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes, cudaStream_t stream) {
  using Gemm = typename GemmTraits::Gemm;
  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementSF = typename GemmTraits::ElementSF;
  using ElementD = typename GemmTraits::ElementOutput;
  using StrideA = typename GemmTraits::StrideA;
  using StrideB = typename GemmTraits::StrideB;
  using StrideD = typename GemmTraits::StrideD;
  using LayoutSFA = typename GemmTraits::LayoutSFA;
  using LayoutSFB = typename GemmTraits::LayoutSFB;
  using UnderlyingProblemShape =
      typename GemmTraits::ProblemShape::UnderlyingProblemShape;

  // Runtime hardware info: SM count plus the preferred/fallback cluster
  // shapes from the MMA config (the cluster shape is runtime-valued).
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = c10::cuda::current_device();
  hw_info.sm_count =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
  hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;

  int num_experts = (int)problem_sizes.size(0);

  // problem_sizes rows are (m, n, k) int32 triples reinterpreted as the
  // device-side problem-shape array; the host copy is passed as nullptr.
  UnderlyingProblemShape* underlying_problem_shape =
      reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());

  typename Gemm::Arguments arguments = {
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, underlying_problem_shape, nullptr},
      // Mainloop: per-expert A/B pointer arrays, strides, and scale factors.
      {reinterpret_cast<const ElementA**>(a_ptrs.data_ptr()),
       reinterpret_cast<StrideA*>(stride_a.data_ptr()),
       reinterpret_cast<const ElementB**>(b_ptrs.data_ptr()),
       reinterpret_cast<StrideB*>(stride_b.data_ptr()),
       reinterpret_cast<const ElementSF**>(sfa_ptrs.data_ptr()),
       reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
       reinterpret_cast<const ElementSF**>(sfb_ptrs.data_ptr()),
       reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())},
      // Epilogue: no C operand (ElementC is void in the traits), D only.
      {{},
       nullptr,
       nullptr,
       reinterpret_cast<ElementD**>(d_ptrs.data_ptr()),
       reinterpret_cast<StrideD*>(stride_d.data_ptr())},
      hw_info,
      {}  // Scheduler
  };

  Gemm gemm;

  auto can_implement_status = gemm.can_implement(arguments);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
              "Failed to implement GEMM");

  // Workspace owned by torch so it is freed with normal tensor semantics.
  torch::TensorOptions options_uint8 =
      torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
  size_t workspace_size = gemm.get_workspace_size(arguments);
  torch::Tensor workspace = torch::empty(workspace_size, options_uint8);

  auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm.run(stream, nullptr, true);  // Enable PDL
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
|
||||
|
||||
// Per-output-dtype entry point: allocates the per-expert scratch arrays
// (device pointers and strides stored as int64 words, scale-factor layouts
// as int32 words), runs the precompute pass, then the grouped GEMM. Both
// launches go to `stream`.
template <typename OutType>
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
    const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
    const torch::Tensor& sfb, torch::Tensor& d,
    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
  int num_experts = (int)problem_sizes.size(0);
  torch::TensorOptions options_int64 =
      torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::TensorOptions options_int32 =
      torch::TensorOptions().dtype(torch::kInt32).device(a.device());

  // Device pointer arrays: one 64-bit pointer per expert.
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64);
  torch::Tensor d_ptrs = torch::empty(num_experts, options_int64);

  // Per-expert CUTLASS strides / SF layouts stored as raw words.
  // NOTE(review): assumes LayoutSFA/LayoutSFB each occupy exactly
  // 5 * sizeof(int32) and each stride fits one int64 — verify against the
  // CUTLASS types instantiated by the traits.
  torch::Tensor stride_a = torch::empty(num_experts, options_int64);
  torch::Tensor stride_b = torch::empty(num_experts, options_int64);
  torch::Tensor stride_d = torch::empty(num_experts, options_int64);
  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32);
  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32);

  using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;
  cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
      a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d,
      layout_sfa, layout_sfb, a, b, sfa, sfb, d, problem_sizes, expert_offsets,
      blockscale_offsets, stream);
  cutlass_mxfp8_grouped_mm<GemmTraits>(
      a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d,
      layout_sfa, layout_sfb, problem_sizes, stream);
}
|
||||
|
||||
} // namespace expert_specialization
|
||||
127
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
Normal file
127
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
Normal file
@@ -0,0 +1,127 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
// Misc
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_size.h"
|
||||
|
||||
// Collective Builder
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
// Integration
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Different configs for 1SM and 2SM MMA kernel
//
// MMA1SMConfig: 128x128x128 tile with the 1-SM ptr-array TMA warp-
// specialized MXF8F6F4 schedule. preferred_cluster / fallback_cluster are
// fed into KernelHardwareInfo by the launcher (the cluster shape itself is
// runtime-valued in the traits).
struct MMA1SMConfig {
  using MmaTileShape = Shape<_128, _128, _128>;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  const static dim3 preferred_cluster;
  const static dim3 fallback_cluster;
};
// `inline` (C++17) is required here: these are namespace-scope definitions
// living in a header (.cuh), so without it every translation unit including
// this header would emit its own definition of the two objects and violate
// the one-definition rule (multiple-definition link error).
inline const dim3 MMA1SMConfig::preferred_cluster(1, 4, 1);
inline const dim3 MMA1SMConfig::fallback_cluster(1, 2, 1);
|
||||
|
||||
// Compile-time CUTLASS configuration for the SM100 MXFP8 block-scaled
// grouped GEMM: fp8 e4m3 operands with mx block scale factors, fp32
// accumulation, and a caller-selected half-precision output dtype.
template <typename _MMAConfig, typename OutputDtype>
struct CutlassMxfp8GroupedMmGemmTraits {
  using MMAConfig = _MMAConfig;
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutputDtype;
  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

  // A matrix configuration
  using ElementA = cutlass::mx_float8_t<ElementInput>;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr static int AlignmentA = 32;  // in elements

  // B matrix configuration
  using ElementB = cutlass::mx_float8_t<ElementInput>;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr static int AlignmentB = 32;  // in elements

  // C/D matrix configuration (C is unused: ElementC = void).
  using ElementC = void;
  using ElementD = ElementOutput;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr static int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
  constexpr static int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  using ElementAccumulator = float;

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  // Identity epilogue tree: D = acc (no alpha/beta scaling, no C addend).
  using CustomEVTIdentity =  // acc
      cutlass::epilogue::fusion::Sm90EVT<
          cutlass::epilogue::fusion::Sm90Compute<
              cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator,
              RoundStyle>,
          cutlass::epilogue::fusion::Sm90AccFetch>;

  // Core kernel configurations
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
  using StageCountType = cutlass::gemm::collective::StageCountAuto;

  // Runtime Cluster Shape (selected through KernelHardwareInfo at launch).
  using ClusterShape = Shape<int32_t, int32_t, _1>;

  // Define Epilogue
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, typename MMAConfig::MmaTileShape,
          ClusterShape, Shape<_64, _64>, ElementAccumulator, ElementAccumulator,
          ElementC, LayoutC*, AlignmentC, ElementD, LayoutD*, AlignmentD,
          typename MMAConfig::EpilogueSchedule,
          CustomEVTIdentity>::CollectiveOp;

  // Define Mainloop (stage count auto-derived after carving out the
  // epilogue's shared-memory footprint).
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
          LayoutB*, AlignmentB, ElementAccumulator,
          typename MMAConfig::MmaTileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          typename MMAConfig::KernelSchedule>::CollectiveOp;

  // Define GemmKernel
  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                           CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Re-exported kernel-internal types used by the launcher and functors.
  using ElementSF = typename Gemm::GemmKernel::ElementSF;
  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
  using LayoutSFA =
      typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
  using LayoutSFB =
      typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
  using Sm1xxBlkScaledConfig =
      typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
};
|
||||
|
||||
} // namespace expert_specialization
|
||||
60
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
Normal file
60
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "mxfp8_experts_quant.cuh"
|
||||
|
||||
// Per-expert MXFP8 group quantization entry point.
//
// Quantizes `input` (fp16/bf16, row major, k % 128 == 0) to fp8 e4m3 in
// `quant_output` with block scale factors written to `scale_factor`. Row
// groups are described by `problem_sizes` (2-D int32, one row per group);
// expert_offsets / blockscale_offsets give each group's starting row in the
// data and scale-factor buffers respectively. Fails with a TORCH_CHECK when
// CUTLASS SM100 support was not compiled in.
void mxfp8_experts_quant(const torch::Tensor& input,
                         const torch::Tensor& problem_sizes,
                         const torch::Tensor& expert_offsets,
                         const torch::Tensor& blockscale_offsets,
                         torch::Tensor& quant_output,
                         torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  // Shape / dtype preconditions.
  TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
  TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
  TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32");
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
              "expert_offsets must be int32");
  TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
              "blockscale_offsets must be int32");

  auto groups = problem_sizes.size(0);
  TORCH_CHECK(
      expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
      "expert_offsets must be 1D and have size equal to the number of groups");
  TORCH_CHECK(
      blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
      "blockscale_offsets must be 1D and have size equal to the number of "
      "groups");

  // NOTE(review): `stream` is fetched but never passed to the launcher —
  // confirm launch_mxfp8_experts_quant uses the current stream internally.
  auto stream = at::cuda::getCurrentCUDAStream();
  // Dispatch on the half-precision input dtype.
  if (input.dtype() == torch::kBFloat16) {
    expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
        scale_factor);
  } else if (input.dtype() == torch::kFloat16) {
    expert_specialization::launch_mxfp8_experts_quant<__half>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
        scale_factor);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false,
              "No implemented mxfp8_experts_quant for "
              "current device");
#endif
}
|
||||
|
||||
#include "core/registration.h"

// Register the CUDA implementation of mxfp8_experts_quant so the torch
// dispatcher routes CUDA-tensor calls here.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
}
|
||||
414
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
Normal file
414
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
Normal file
@@ -0,0 +1,414 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
// Adapted from SGLang:
|
||||
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
|
||||
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cuda/ptx>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Launch/tiling constants for the MXFP8 group quantization kernel.
constexpr uint32_t THREAD_BLOCK_SIZE = 128;  // threads per CTA
constexpr uint32_t WARP_SIZE = 32;
constexpr int BLOCK_M = 128;  // tile rows
constexpr int BLOCK_K = 128;  // tile columns (elements along k)
// Thread/value layouts for the main copy: 16x8 threads with 16 values per
// thread, so each pass covers 16 rows x (8 * 16 = 128) columns of the tile.
using ThrLayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;
using ValLayout = Layout<Shape<_1, _16>>;
// Layouts for the scale-factor store (R2S: presumably register-to-shared,
// per cute naming — TODO confirm): 16x4 threads, one value each.
using SfR2SThrLayout = Layout<Shape<_16, _4>, Stride<_4, _1>>;
using SfR2SValLayout = Layout<Shape<_1, _1>>;
// Swizzled scale-factor tile layout. NOTE(review): verify this matches the
// CUTLASS Sm1xx block-scaled scale-factor atom expected by the GEMM.
using ScaleFactorTileLayout =
    Layout<Shape<Shape<_32, _4>, _4>, Stride<Stride<_16, _4>, _1>>;
|
||||
|
||||
// Fast reciprocal.
// Computes an approximate 1/a using the PTX `rcp.approx.ftz.f32`
// instruction (flush-to-zero; not IEEE-exact division).
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
|
||||
|
||||
// Some code references TRT-LLM:
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh

// Quantize the 16 fp16/bf16 values of one register fragment to MXFP8 (e4m3).
//
// Two adjacent lanes cooperate on a single 32-element scale block: each lane
// reduces its own 16 values to an absolute maximum and exchanges it with its
// partner lane via __shfl_xor_sync(..., 1), so both lanes agree on the block
// maximum.  The scale is block_max / 448.0f (448 is labeled below as the
// mxfp8 maximum), rounded UP to a power of two as an e8m0 byte; the inputs
// are multiplied by the reciprocal of the decoded scale and cast to e4m3.
//
// fragment_s: input fragment, 16 __half or __nv_bfloat16 values (static layout).
// fragment_d: output fragment, 16 cutlass::float_e4m3_t values (static layout).
// Returns the e8m0 scale byte shared by this lane pair's 32-element block.
template <typename FragmentS, typename FragmentD>
__inline__ __device__ uint8_t cvt_warp_fp16_to_mxfp8(FragmentS& fragment_s,
                                                     FragmentD& fragment_d) {
  using FragmentSLayout = typename FragmentS::layout_type;
  using FragmentDLayout = typename FragmentD::layout_type;
  FragmentSLayout fragment_s_layout;
  FragmentDLayout fragment_d_layout;
  // Each lane must carry exactly 16 source and 16 destination elements.
  static_assert(is_static<FragmentSLayout>::value &&
                size(fragment_s_layout) == 16);
  static_assert(is_static<FragmentDLayout>::value &&
                size(fragment_d_layout) == 16);

  constexpr int eles_per_thr = 16;
  using ValType = typename FragmentS::element_type;
  // Pack into 2-wide vector types so abs/max use the paired half2 intrinsics.
  using VecType = std::conditional_t<std::is_same_v<ValType, __nv_bfloat16>,
                                     __nv_bfloat162, __half2>;
  VecType vec[8];
  // Assign vals (static Int<> indices keep the fragment in registers).
  vec[0].x = fragment_s(Int<0>{});
  vec[0].y = fragment_s(Int<1>{});
  vec[1].x = fragment_s(Int<2>{});
  vec[1].y = fragment_s(Int<3>{});
  vec[2].x = fragment_s(Int<4>{});
  vec[2].y = fragment_s(Int<5>{});
  vec[3].x = fragment_s(Int<6>{});
  vec[3].y = fragment_s(Int<7>{});
  vec[4].x = fragment_s(Int<8>{});
  vec[4].y = fragment_s(Int<9>{});
  vec[5].x = fragment_s(Int<10>{});
  vec[5].y = fragment_s(Int<11>{});
  vec[6].x = fragment_s(Int<12>{});
  vec[6].y = fragment_s(Int<13>{});
  vec[7].x = fragment_s(Int<14>{});
  vec[7].y = fragment_s(Int<15>{});

  // Per-lane absolute maximum over the 16 values (kept as a 2-wide max).
  auto local_max = __habs2(vec[0]);
  for (int i = 1; i < eles_per_thr / 2; i++) {
    local_max = __hmax2(__habs2(vec[i]), local_max);
  }
  // Merge with the partner lane (lane ^ 1): the pair shares one scale block.
  local_max = __hmax2(__shfl_xor_sync(uint32_t(-1), local_max, 1), local_max);

  // Get the final absolute maximum values.
  float block_max(0.0f);
  if constexpr (std::is_same_v<ValType, __nv_bfloat16>) {
    block_max = __bfloat162float(__hmax(local_max.x, local_max.y));
  } else {
    block_max = __half2float(__hmax(local_max.x, local_max.y));
  }
  // Get the SF (max value of the vector / max value of mxfp8).
  float sf_val = block_max * reciprocal_approximate_ftz(448.0f);
  // 8 bits representation of the SF.
  uint8_t fp8_sf_val;

  // Round the scale UP to a power of two (e8m0, cudaRoundPosInf) so scaled
  // values cannot overflow the e4m3 range.
  __nv_fp8_e8m0 tmp_sf_val;
  tmp_sf_val.__x =
      __nv_cvt_float_to_e8m0(sf_val, __NV_SATFINITE, cudaRoundPosInf);
  sf_val = static_cast<float>(tmp_sf_val);  // decoded (power-of-two) scale
  fp8_sf_val = tmp_sf_val.__x;
  // Get the output scale (reciprocal of the SFValue).
  // block_max == 0 would give sf_val == 0; force the scale to 0 instead of
  // taking the reciprocal of zero.
  float output_scale =
      block_max != 0.f ? reciprocal_approximate_ftz(sf_val) : 0.0f;

  // Convert the input to float.
  float2 fp2_vals[eles_per_thr / 2];

#pragma unroll
  for (int i = 0; i < eles_per_thr / 2; i++) {
    if constexpr (std::is_same_v<ValType, __half>) {
      fp2_vals[i] = __half22float2(vec[i]);
    } else {
      fp2_vals[i] = __bfloat1622float2(vec[i]);
    }
    fp2_vals[i].x *= output_scale;
    fp2_vals[i].y *= output_scale;
  }
  // Convert the scaled floats to e4m3 two at a time; the union exposes the
  // packed conversion results as individual bytes.
  union {
    uint8_t bytes[16];
    __nv_fp8x2_e4m3 elts[8];
  } u;
  u.elts[0] = __nv_fp8x2_e4m3(fp2_vals[0]);
  u.elts[1] = __nv_fp8x2_e4m3(fp2_vals[1]);
  u.elts[2] = __nv_fp8x2_e4m3(fp2_vals[2]);
  u.elts[3] = __nv_fp8x2_e4m3(fp2_vals[3]);
  u.elts[4] = __nv_fp8x2_e4m3(fp2_vals[4]);
  u.elts[5] = __nv_fp8x2_e4m3(fp2_vals[5]);
  u.elts[6] = __nv_fp8x2_e4m3(fp2_vals[6]);
  u.elts[7] = __nv_fp8x2_e4m3(fp2_vals[7]);
  // Scatter the 16 quantized bytes into the destination fragment.
  fragment_d(Int<0>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[0]);
  fragment_d(Int<1>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[1]);
  fragment_d(Int<2>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[2]);
  fragment_d(Int<3>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[3]);
  fragment_d(Int<4>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[4]);
  fragment_d(Int<5>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[5]);
  fragment_d(Int<6>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[6]);
  fragment_d(Int<7>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[7]);
  fragment_d(Int<8>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[8]);
  fragment_d(Int<9>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[9]);
  fragment_d(Int<10>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[10]);
  fragment_d(Int<11>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[11]);
  fragment_d(Int<12>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[12]);
  fragment_d(Int<13>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[13]);
  fragment_d(Int<14>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[14]);
  fragment_d(Int<15>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[15]);
  return fp8_sf_val;
}
|
||||
|
||||
// Quantize one 128x128 tile to MXFP8.
//
// Data path: global -> registers (predicated, tiled_copy_g2r), convert via
// cvt_warp_fp16_to_mxfp8, registers -> global (predicated, tiled_copy_r2g).
// Scale path: each thread pair's e8m0 byte is staged into the 512-byte
// shared-memory tile, then thread 0 issues a single 512-byte
// cuda::ptx::cp_async_bulk to write all scales to global memory.
//
// tensor_s / tensor_p / tensor_d: 128x128 source / predicate / destination
//   tiles; the static_asserts pin the shapes and the unit inner stride.
// tensor_shared_sf / tensor_sf: 128x4 scale tiles in shared / global memory.
// m: number of valid rows in this tile's group; 16-row sub-tiles entirely
//   past m are skipped.
template <typename TensorS, typename TensorP, typename TensorD,
          typename TensorSharedSF, typename TensorSF, typename TiledCopyG2R,
          typename TiledCopyR2G, typename TiledCopyR2S>
__inline__ __device__ void mxfp8_experts_quant_tile(
    TensorS& tensor_s, TensorP& tensor_p, TensorD& tensor_d,
    TensorSharedSF& tensor_shared_sf, TensorSF& tensor_sf, int m,
    TiledCopyG2R& tiled_copy_g2r, TiledCopyR2G& tiled_copy_r2g,
    TiledCopyR2S& tiled_copy_r2s) {
  static_assert(size(get<0>(typename TensorS::layout_type{})) == 128 &&
                size(get<1>(typename TensorS::layout_type{})) == 128 &&
                stride(get<1>(typename TensorS::layout_type{})) == 1);
  static_assert(size(get<0>(typename TensorD::layout_type{})) == 128 &&
                size(get<1>(typename TensorD::layout_type{})) == 128 &&
                stride(get<1>(typename TensorD::layout_type{})) == 1);
  static_assert(size(get<0>(typename TensorP::layout_type{})) == 128 &&
                size(get<1>(typename TensorP::layout_type{})) == 128);
  static_assert(size(get<0>(typename TensorSharedSF::layout_type{})) == 128 &&
                size(get<1>(typename TensorSharedSF::layout_type{})) == 4);
  static_assert(size(get<0>(typename TensorSF::layout_type{})) == 128 &&
                size(get<1>(typename TensorSF::layout_type{})) == 4);

  // The data copies cover 16 rows x 128 columns per iteration.
  using Tiler_MN = typename TiledCopyG2R::Tiler_MN;
  auto tiler_mn = Tiler_MN{};
  static_assert(size<0>(tiler_mn) == 16 && size<1>(tiler_mn) == 128);

  // Split the 128x128 tile into 16-row sub-tiles (8 iterations).
  auto tiled_tensor_s = tiled_divide(tensor_s, tiler_mn);
  auto tiled_tensor_p = tiled_divide(tensor_p, tiler_mn);
  auto tiled_tensor_d = tiled_divide(tensor_d, tiler_mn);
  static_assert(size<2>(tiled_tensor_s) == 1);
  static_assert(size<2>(tiled_tensor_p) == 1);
  static_assert(size<2>(tiled_tensor_d) == 1);
  auto squeeze_tiled_tensor_s = take<0, 2>(tiled_tensor_s);
  auto squeeze_tiled_tensor_p = take<0, 2>(tiled_tensor_p);
  auto squeeze_tiled_tensor_d = take<0, 2>(tiled_tensor_d);

  // Scale-factor copies cover 16 rows x 4 scale columns per iteration.
  using SF_Tiler_MN = typename TiledCopyR2S::Tiler_MN;
  auto sf_tiler_mn = SF_Tiler_MN{};
  static_assert(size<0>(sf_tiler_mn) == 16 && size<1>(sf_tiler_mn) == 4);

  auto tiled_tensor_sf = tiled_divide(tensor_sf, sf_tiler_mn);
  auto tiled_tensor_shared_sf = tiled_divide(tensor_shared_sf, sf_tiler_mn);
  auto squeeze_tiled_tensor_sf = take<0, 2>(tiled_tensor_sf);
  auto squeeze_tiled_tensor_shared_sf = take<0, 2>(tiled_tensor_shared_sf);

  constexpr int tile_loop_count = size<1>(tiled_tensor_s);
  constexpr int rows_in_tile = 16;
  // We don't need to clear shared memory
  // clear(squeeze_tiled_tensor_shared_sf);
#pragma unroll 4
  for (int t = 0; t < tile_loop_count; t++) {
    // m is uniform across the block, so this break is non-divergent and the
    // __syncthreads() below stay safe.
    if (t * rows_in_tile >= m) {
      break;
    }
    auto current_copy_tile_s = tensor<0>(squeeze_tiled_tensor_s(_, t));
    auto current_copy_tile_p = tensor<0>(squeeze_tiled_tensor_p(_, t));
    auto current_copy_tile_d = tensor<0>(squeeze_tiled_tensor_d(_, t));
    auto current_copy_tile_sf = tensor<0>(squeeze_tiled_tensor_sf(_, t));
    auto current_copy_tile_shared_sf =
        tensor<0>(squeeze_tiled_tensor_shared_sf(_, t));

    // Global to Register copy
    auto thr_copy_g2r = tiled_copy_g2r.get_thread_slice(threadIdx.x);
    auto thr_tile_g2r_s = thr_copy_g2r.partition_S(current_copy_tile_s);
    auto thr_tile_g2r_p = thr_copy_g2r.partition_S(current_copy_tile_p);
    auto input_fragment = make_fragment_like(thr_tile_g2r_s);

    // Register to Global copy
    auto thr_copy_r2g = tiled_copy_r2g.get_thread_slice(threadIdx.x);
    auto thr_tile_r2g_d = thr_copy_r2g.partition_D(current_copy_tile_d);
    auto thr_tile_r2g_p = thr_copy_r2g.partition_D(current_copy_tile_p);
    auto output_fragment = make_fragment_like(thr_tile_r2g_d);

    // Register to Shared copy: two consecutive threads share one scale byte,
    // so the r2s copy is indexed by thread pair (threadIdx.x / 2) and only
    // even threads issue it below.
    auto thr_copy_r2s = tiled_copy_r2s.get_thread_slice(threadIdx.x / 2);
    auto thr_tile_r2s_shared_sf =
        thr_copy_r2s.partition_D(current_copy_tile_shared_sf);
    auto shared_sf_fragment = make_fragment_like(thr_tile_r2s_shared_sf);

    // CopyG2R & convert & CopyR2G (predicated on the in-bounds mask).
    copy_if(tiled_copy_g2r, thr_tile_g2r_p, thr_tile_g2r_s, input_fragment);
    uint8_t fp8_sf_val =
        cvt_warp_fp16_to_mxfp8(input_fragment, output_fragment);
    copy_if(tiled_copy_r2g, thr_tile_r2g_p, output_fragment, thr_tile_r2g_d);
    shared_sf_fragment[0] = fp8_sf_val;

    // Before first copy r2s, clear shared memory and wait previous group
    if (t == 0 && threadIdx.x == 0) {
      // Wait for the group to have completed reading from shared memory.
      cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>());
    }
    __syncthreads();

    if (threadIdx.x % 2 == 0) {
      copy(tiled_copy_r2s, shared_sf_fragment, thr_tile_r2s_shared_sf);
    }
    __syncthreads();
  }

  // Wait for shared memory writes to be visible to TMA engine.
  cuda::ptx::fence_proxy_async(cuda::ptx::space_shared);  // b)
  __syncthreads();

  if (threadIdx.x == 0) {
    // One bulk copy of the whole 512-byte scale tile (shared -> global).
    cuda::ptx::cp_async_bulk(cuda::ptx::space_global, cuda::ptx::space_shared,
                             squeeze_tiled_tensor_sf.data().get(),
                             squeeze_tiled_tensor_shared_sf.data().get(), 512);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cuda::ptx::cp_async_bulk_commit_group();
  }
  __syncthreads();
}
|
||||
|
||||
// Grouped MXFP8 quantization kernel. Requires SM100+ (__CUDA_ARCH__ >= 1000);
// compiles to an empty body on older architectures.
//
// For each group g: problem_sizes[g*3 + 0] is m, problem_sizes[g*3 + 2] is k
// (the middle entry, n, is unused here). The group's input/output slabs start
// at row expert_offsets[g]; its scale rows start at blockscale_offsets[g]
// with the row count padded up to a multiple of 128. Blocks walk each group's
// 128x128 tiles in a grid-stride fashion (blk_offset += gridDim.x).
template <typename T_IN, typename TiledCopyG2R, typename TiledCopyR2G,
          typename TiledCopyR2S>
__global__ void mxfp8_experts_quant_kernel(
    const T_IN* input, const int* problem_sizes, const int* expert_offsets,
    const int* blockscale_offsets, cutlass::float_e4m3_t* quant_output,
    uint8_t* scale_factor, int groups, TiledCopyG2R tiled_copy_g2r,
    TiledCopyR2G tiled_copy_r2g, TiledCopyR2S tiled_copy_r2s) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
  // Staging buffer for one 512-byte scale-factor tile; 512-byte alignment is
  // required for the bulk (TMA) copy issued in mxfp8_experts_quant_tile.
  __shared__ __align__(512) uint8_t shared_memory[512];
  ScaleFactorTileLayout scale_factor_tile_layout{};
  auto scale_factor_shared =
      make_tensor(make_smem_ptr(shared_memory),
                  scale_factor_tile_layout);  // ((_32,_4), _4):((_16,_4), _1)
  // TODO: Transform Groupwise Schedule into a more efficient Schedule
  for (int g = 0; g < groups; g++) {
    int m = problem_sizes[g * 3 + 0];
    int k = problem_sizes[g * 3 + 2];
    // Widen before multiplying by k to avoid 32-bit overflow in the offsets.
    int64_t expert_offset = static_cast<int64_t>(expert_offsets[g]);
    int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[g]);

    auto input_tensor = make_tensor(
        make_gmem_ptr(input + expert_offset * k),
        make_layout(make_shape(m, k),
                    LayoutRight{}));  // (M, K):(K, 1) half_t/bfloat16_t

    auto quant_output_tensor = make_tensor(
        make_gmem_ptr(quant_output + expert_offset * k),
        make_layout(make_shape(m, k),
                    LayoutRight{}));  // (M, K):(K, 1) cutlass::float_e4m3_t

    // Scale rows are padded to a multiple of 128; one scale per 32 K-elems.
    auto scale_factor_shape = make_shape(ceil_div(m, 128) * 128, k / 32);
    auto scale_factor_layout = tile_to_shape(scale_factor_tile_layout,
                                             scale_factor_shape, LayoutRight{});
    // layout<0>(layout<0>(scale_factor_layout)) (_32,_4):(_16,_4) -- static
    // layout<1>(layout<0>(scale_factor_layout)) M_align_128 / 128 -- dynamic
    // shape dynamic stride layout<0>(layout<1>(scale_factor_layout)) _4:_1 --
    // static layout<1>(layout<1>(scale_factor_layout)) (K / 32) / 4 : _512 --
    // dynamic shape static stride

    // Reshape to zipped layout for 1D indexing
    auto zipped_scale_factor_layout = make_layout(
        make_layout(layout<0>(layout<0>(scale_factor_layout)),
                    layout<0>(layout<1>(scale_factor_layout))),
        make_layout(
            layout<1>(layout<0>(scale_factor_layout)),
            layout<1>(layout<1>(
                scale_factor_layout))));  // (((_32,_4),_4),(M_align_128 /
                                          // 128,(K / 32) /
                                          // 4)):(((_16,_4),_1),(?,_512))

    auto scale_factor_tensor =
        make_tensor(make_gmem_ptr(scale_factor + blockscale_offset * (k / 32)),
                    zipped_scale_factor_layout);

    // Used for cases where M is not divisible by 128 (most scenarios).
    auto input_shape = shape(input_tensor);  // (M, K):(K, 1)
    auto identity_tensor = make_identity_tensor(input_shape);
    auto predict_tensor = cute::lazy::transform(
        identity_tensor, [&](auto c) { return elem_less(c, input_shape); });

    // (_128, _128)
    auto tiler = make_shape(Int<BLOCK_M>{}, Int<BLOCK_K>{});

    auto tiled_input_tensor = zipped_divide(
        input_tensor, tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto tiled_quant_output_tensor =
        zipped_divide(quant_output_tensor,
                      tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto tiled_predict_tensor = zipped_divide(
        predict_tensor, tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))

    auto total_tiles =
        size<1>(tiled_input_tensor);  // cdiv(M, 128) * cdiv(K, 128)
    // Grid-stride walk over the group's tiles; the scale tensor's second
    // mode has (K/32)/4 = cdiv(K, 128) columns, so blk_offset indexes it the
    // same way as the data tiles.
    decltype(total_tiles) blk_offset = blockIdx.x;
    while (blk_offset < total_tiles) {
      auto current_input_tile = tensor<0>(tiled_input_tensor(_, blk_offset));
      auto current_quant_output_tile =
          tensor<0>(tiled_quant_output_tensor(_, blk_offset));
      auto current_predict_tile =
          tensor<0>(tiled_predict_tensor(_, blk_offset));
      auto current_scale_factor_tile =
          tensor<0>(scale_factor_tensor(_, blk_offset));

      mxfp8_experts_quant_tile<
          decltype(current_input_tile), decltype(current_predict_tile),
          decltype(current_quant_output_tile), decltype(scale_factor_shared),
          decltype(current_scale_factor_tile), TiledCopyG2R, TiledCopyR2G,
          TiledCopyR2S>(current_input_tile, current_predict_tile,
                        current_quant_output_tile, scale_factor_shared,
                        current_scale_factor_tile, m, tiled_copy_g2r,
                        tiled_copy_r2g, tiled_copy_r2s);
      blk_offset += gridDim.x;
    }
  }
#endif
}
|
||||
|
||||
// Host-side launcher for mxfp8_experts_quant_kernel.
//
// Builds the tiled-copy descriptors (global<->register for the 16x128 data
// path, register->shared for the 16x4 scale-factor path), sizes a persistent
// grid from the kernel's occupancy, and launches on the current CUDA stream.
// quant_output and scale_factor are written in place by the kernel.
template <typename T_IN>
void launch_mxfp8_experts_quant(const torch::Tensor& input,
                                const torch::Tensor& problem_sizes,
                                const torch::Tensor& expert_offsets,
                                const torch::Tensor& blockscale_offsets,
                                torch::Tensor& quant_output,
                                torch::Tensor& scale_factor) {
  ThrLayout thr_layout{};
  ValLayout val_layout{};
  SfR2SThrLayout r2s_thr_layout{};
  SfR2SValLayout r2s_val_layout{};

  // Global -> register: vectorized loads of size(val_layout) input elements.
  using CopyOpG2R =
      UniversalCopy<cutlass::AlignedArray<T_IN, size(val_layout)>>;
  using CopyAtomG2R = cute::Copy_Atom<CopyOpG2R, T_IN>;
  auto tiled_copy_g2r = cute::make_tiled_copy(
      CopyAtomG2R{}, thr_layout, val_layout);  // Tiler_MN: (16, 128)

  // Register -> global: vectorized stores of the quantized e4m3 values.
  using CopyOpR2G = UniversalCopy<
      cutlass::AlignedArray<cutlass::float_e4m3_t, size(val_layout)>>;
  using CopyAtomR2G = cute::Copy_Atom<CopyOpR2G, cutlass::float_e4m3_t>;
  auto tiled_copy_r2g = cute::make_tiled_copy(
      CopyAtomR2G{}, thr_layout, val_layout);  // Tiler_MN: (16, 128)

  // Register -> shared: one e8m0 scale byte per thread pair.
  using CopyOpR2S =
      UniversalCopy<cutlass::AlignedArray<uint8_t, size(r2s_val_layout)>>;
  using CopyAtomR2S = cute::Copy_Atom<CopyOpR2S, uint8_t>;
  auto tiled_copy_r2s = cute::make_tiled_copy(
      CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout);  // Tiler_MN: (16, 4)

  // Size a persistent grid: one wave that exactly fills the device; the
  // kernel grid-strides over tiles, so correctness does not depend on this.
  int max_active_blocks_per_sm = -1;
  AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks_per_sm,
      mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
                                 decltype(tiled_copy_r2g),
                                 decltype(tiled_copy_r2s)>,
      THREAD_BLOCK_SIZE, 0));

  dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
                max_active_blocks_per_sm,
            1, 1);
  dim3 block(THREAD_BLOCK_SIZE, 1, 1);
  int num_experts = (int)problem_sizes.size(0);
  auto stream = at::cuda::getCurrentCUDAStream();
  mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
                             decltype(tiled_copy_r2g), decltype(tiled_copy_r2s)>
      <<<grid, block, 0, stream>>>(
          reinterpret_cast<const T_IN*>(input.data_ptr()),
          reinterpret_cast<const int*>(problem_sizes.data_ptr()),
          reinterpret_cast<const int*>(expert_offsets.data_ptr()),
          reinterpret_cast<const int*>(blockscale_offsets.data_ptr()),
          reinterpret_cast<cutlass::float_e4m3_t*>(quant_output.data_ptr()),
          reinterpret_cast<uint8_t*>(scale_factor.data_ptr()), num_experts,
          tiled_copy_g2r, tiled_copy_r2g, tiled_copy_r2s);
  // Kernel launches fail silently; surface launch-configuration errors
  // (e.g. a zero-sized grid if occupancy came back 0) immediately.
  AT_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
} // namespace expert_specialization
|
||||
@@ -426,6 +426,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
  // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
  // quant_output and scale_factor are declared Tensor! (mutated in place).
  ops.def(
      "mxfp8_experts_quant("
      " Tensor input, Tensor problem_sizes, Tensor expert_offsets,"
      " Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)"
      " -> ()");
  // conditionally compiled so impl registration is in source file

  // Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+).
  // out is declared Tensor! (mutated in place).
  ops.def(
      "cutlass_mxfp8_grouped_mm("
      " Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out,"
      " Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)"
      " -> ()");
  // conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
|
||||
237
tests/kernels/moe/test_cutlass_mxfp8_grouped_mm.py
Normal file
237
tests/kernels/moe/test_cutlass_mxfp8_grouped_mm.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from SGLang:
|
||||
# https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py
|
||||
|
||||
"""Tests for SM100 CUTLASS MXFP8 grouped MoE kernels."""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import torch_moe_single
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
# Fix both RNG sources so the randomized shapes/data below are reproducible.
random.seed(42)
set_random_seed(42)
|
||||
|
||||
|
||||
def align(val: int, alignment: int = 128) -> int:
    """Round ``val`` up to the nearest multiple of ``alignment``."""
    remainder = val % alignment
    if remainder == 0:
        return int(val)
    return int(val + alignment - remainder)
|
||||
|
||||
|
||||
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    """Cosine-style relative difference: 0 for identical tensors, up to 1."""
    xd = x.double()
    yd = y.double()
    energy = (xd * xd + yd * yd).sum()
    similarity = 2 * (xd * yd).sum() / energy
    return 1 - similarity
|
||||
|
||||
|
||||
def is_sm100_supported() -> bool:
    """Whether the current device is CUDA with compute capability family 10.0."""
    if not current_platform.is_cuda():
        return False
    return current_platform.is_device_capability_family(100)
|
||||
|
||||
|
||||
def compute_ref_output(
    input_tensor: torch.Tensor,
    weight_list: list[torch.Tensor],
    expert_offsets: list[int],
    expert_offset: int,
    num_experts: int,
) -> torch.Tensor:
    """Reference MoE output via ``torch_moe_single`` with synthetic routing.

    Each token is routed (top-1) to the expert whose contiguous row range
    contains it; all other experts get a large negative score.
    """
    num_tokens = expert_offset
    score = torch.full(
        (num_tokens, num_experts),
        -1e9,
        device=input_tensor.device,
        dtype=torch.float32,
    )
    for expert_id in range(num_experts):
        row_begin = expert_offsets[expert_id]
        if expert_id + 1 < num_experts:
            row_end = expert_offsets[expert_id + 1]
        else:
            row_end = num_tokens
        score[row_begin:row_end, expert_id] = 0.0

    stacked_weights = torch.stack(weight_list, dim=0)
    return torch_moe_single(input_tensor, stacked_weights, score, topk=1)
|
||||
|
||||
|
||||
def compute_kernel_output(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    problem_sizes: list[list[int]],
    aux_problem_sizes: list[list[int]],
    expert_offsets: list[int],
    aux_expert_offsets: list[int],
    input_blockscale_offsets: list[int],
    weight_blockscale_offsets: list[int],
    input_blockscale_offset: int,
    n_g: int,
    k_g: int,
    num_experts: int,
    expert_offset: int,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Quantize activations and weights to MXFP8, then run the grouped MM."""
    device = input_tensor.device

    def _as_int32(values: list) -> torch.Tensor:
        # Kernel metadata must be int32 and resident on the compute device.
        return torch.tensor(values).to(device=device, dtype=torch.int32)

    _problem_sizes = _as_int32(problem_sizes)
    _aux_problem_sizes = _as_int32(aux_problem_sizes)
    _expert_offsets = _as_int32(expert_offsets)
    _aux_expert_offsets = _as_int32(aux_expert_offsets)
    _input_blockscale_offsets = _as_int32(input_blockscale_offsets)
    _weight_blockscale_offsets = _as_int32(weight_blockscale_offsets)

    # Output buffers for the quantization passes: e4m3 data plus one uint8
    # (e8m0) scale per 32 elements along K.
    input_quant = torch.zeros_like(
        input_tensor, dtype=torch.float8_e4m3fn, device=device
    )
    input_scale_factor = torch.zeros(
        (input_blockscale_offset, k_g // 32), dtype=torch.uint8, device=device
    )
    weight_quant = torch.zeros_like(
        weight_tensor, dtype=torch.float8_e4m3fn, device=device
    )
    weight_scale_factor = torch.zeros(
        (num_experts, n_g, k_g // 32), dtype=torch.uint8, device=device
    )

    # Quantize the activations, then the weights (weights use the auxiliary
    # sizes/offsets because their "rows" are N rather than M).
    ops.mxfp8_experts_quant(
        input_tensor,
        _problem_sizes,
        _expert_offsets,
        _input_blockscale_offsets,
        input_quant,
        input_scale_factor,
    )
    ops.mxfp8_experts_quant(
        weight_tensor,
        _aux_problem_sizes,
        _aux_expert_offsets,
        _weight_blockscale_offsets,
        weight_quant,
        weight_scale_factor,
    )

    # The grouped MM consumes (E, K, N)-transposed views of the weights.
    weight_quant = weight_quant.view(num_experts, n_g, k_g).transpose(1, 2)
    weight_scale_factor = weight_scale_factor.view(
        num_experts, n_g, k_g // 32
    ).transpose(1, 2)

    output = torch.empty((expert_offset, n_g), device=device, dtype=out_dtype)
    ops.cutlass_mxfp8_grouped_mm(
        input_quant,
        weight_quant,
        input_scale_factor,
        weight_scale_factor,
        output,
        _problem_sizes,
        _expert_offsets,
        _input_blockscale_offsets,
    )
    return output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason=(
        "cutlass_mxfp8_grouped_mm and mxfp8_experts_quant "
        "are only supported on CUDA SM100"
    ),
)
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_cutlass_mxfp8_grouped_mm(num_experts, out_dtype):
    """End-to-end check of MXFP8 quant + grouped MM against a torch reference."""
    device = "cuda"
    alignment = 128
    # N and K are random multiples of 128 (the kernels' alignment requirement).
    n_g = random.randint(1, 64) * alignment
    k_g = random.randint(1, 64) * alignment

    # Running offsets: token rows, weight rows, and 128-aligned scale rows.
    total_m = 0
    m_offsets = []
    total_aux = 0
    aux_offsets = []
    in_bs_total = 0
    in_bs_offsets = []
    w_bs_total = 0
    w_bs_offsets = []
    sizes = []
    aux_sizes = []
    inputs = []
    weights = []

    for _ in range(num_experts):
        m_g = random.randint(1, 512)
        m_offsets.append(total_m)
        total_m += m_g
        aux_offsets.append(total_aux)
        total_aux += n_g
        in_bs_offsets.append(in_bs_total)
        in_bs_total += align(m_g, 128)
        w_bs_offsets.append(w_bs_total)
        w_bs_total += n_g  # n_g already align to 128
        sizes.append([m_g, n_g, k_g])
        aux_sizes.append([n_g, m_g, k_g])

        inputs.append(
            torch.normal(
                0.0, std=1.0, size=(m_g, k_g), device=device, dtype=out_dtype
            )  # (M, K):(K, 1)
        )
        weights.append(
            torch.normal(
                0.0, std=1.0, size=(n_g, k_g), device=device, dtype=out_dtype
            )  # (N, K):(K, 1)
        )

    stacked_input = torch.concat(inputs, dim=0)
    stacked_weight = torch.concat(weights, dim=0)

    ref_output = compute_ref_output(
        input_tensor=stacked_input,
        weight_list=weights,
        expert_offsets=m_offsets,
        expert_offset=total_m,
        num_experts=num_experts,
    )
    output = compute_kernel_output(
        input_tensor=stacked_input,
        weight_tensor=stacked_weight,
        problem_sizes=sizes,
        aux_problem_sizes=aux_sizes,
        expert_offsets=m_offsets,
        aux_expert_offsets=aux_offsets,
        input_blockscale_offsets=in_bs_offsets,
        weight_blockscale_offsets=w_bs_offsets,
        input_blockscale_offset=in_bs_total,
        n_g=n_g,
        k_g=k_g,
        num_experts=num_experts,
        expert_offset=total_m,
        out_dtype=out_dtype,
    )

    # Compare each expert's rows separately against the reference.
    for g in range(num_experts):
        row_begin = m_offsets[g]
        row_end = row_begin + sizes[g][0]
        baseline = ref_output[row_begin:row_end]
        actual = output[row_begin:row_end]
        diff = calc_diff(actual, baseline)
        assert diff < 0.001
        print(
            f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, "
            f"out_dtype={out_dtype}, diff={diff:.5f}: OK"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly without invoking the pytest CLI.
    pytest.main([__file__])
|
||||
@@ -1102,6 +1102,76 @@ def cutlass_fp4_moe_mm(
|
||||
)
|
||||
|
||||
|
||||
def mxfp8_experts_quant(
    input_tensor: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    quant_output: torch.Tensor,
    scale_factor: torch.Tensor,
) -> None:
    """Grouped MXFP8 quantization of per-expert tensors.

    Thin wrapper over the ``_C.mxfp8_experts_quant`` custom op;
    ``quant_output`` and ``scale_factor`` are mutated in place
    (declared ``Tensor!`` in the op schema).
    """
    op_args = (
        input_tensor,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
        quant_output,
        scale_factor,
    )
    torch.ops._C.mxfp8_experts_quant(*op_args)
|
||||
|
||||
|
||||
def cutlass_mxfp8_grouped_mm(
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    out_tensors: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
) -> None:
    """MXFP8 blockscaled grouped GEMM.

    Thin wrapper over the ``_C.cutlass_mxfp8_grouped_mm`` custom op;
    ``out_tensors`` is mutated in place (declared ``Tensor!`` in the schema).
    """
    op_args = (
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        out_tensors,
        problem_sizes,
        expert_offsets,
        blockscale_offsets,
    )
    torch.ops._C.cutlass_mxfp8_grouped_mm(*op_args)
|
||||
|
||||
|
||||
# Register a fake (meta) implementation only when the op was compiled in.
if hasattr(torch.ops._C, "mxfp8_experts_quant"):

    @register_fake("_C::mxfp8_experts_quant")
    def _mxfp8_experts_quant_fake(
        input_tensor: torch.Tensor,
        problem_sizes: torch.Tensor,
        expert_offsets: torch.Tensor,
        blockscale_offsets: torch.Tensor,
        quant_output: torch.Tensor,
        scale_factor: torch.Tensor,
    ) -> None:
        # The real op mutates quant_output/scale_factor in place and returns
        # nothing, so the fake implementation is a no-op.
        return None
|
||||
|
||||
|
||||
# Register a fake (meta) implementation only when the op was compiled in.
if hasattr(torch.ops._C, "cutlass_mxfp8_grouped_mm"):

    @register_fake("_C::cutlass_mxfp8_grouped_mm")
    def _cutlass_mxfp8_grouped_mm_fake(
        a_tensors: torch.Tensor,
        b_tensors: torch.Tensor,
        a_scales: torch.Tensor,
        b_scales: torch.Tensor,
        out_tensors: torch.Tensor,
        problem_sizes: torch.Tensor,
        expert_offsets: torch.Tensor,
        blockscale_offsets: torch.Tensor,
    ) -> None:
        # The real op mutates out_tensors in place and returns nothing, so
        # the fake implementation is a no-op.
        return None
|
||||
|
||||
|
||||
# gptq_marlin
|
||||
def gptq_marlin_repack(
|
||||
b_q_weight: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user