[Perf] Fuse stride preparation for NVFP4 cutlass_moe (#31837)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -62,7 +62,9 @@ __global__ void __get_group_gemm_starts(
|
||||
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
|
||||
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
|
||||
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
|
||||
const int K, const int N) {
|
||||
int64_t* a_strides, int64_t* b_strides, int64_t* c_strides,
|
||||
const int64_t a_stride_val, const int64_t b_stride_val,
|
||||
const int64_t c_stride_val, const int K, const int N) {
|
||||
int64_t expert_id = threadIdx.x;
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
@@ -103,6 +105,11 @@ __global__ void __get_group_gemm_starts(
|
||||
// Shape of alpha = [E]
|
||||
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
|
||||
|
||||
// Initialize strides (constant across all experts, avoids separate kernels)
|
||||
a_strides[expert_id] = a_stride_val;
|
||||
b_strides[expert_id] = b_stride_val;
|
||||
c_strides[expert_id] = c_stride_val;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
@@ -135,7 +142,11 @@ __global__ void __get_group_gemm_starts(
|
||||
static_cast<float*>(alphas.data_ptr()), \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(sf_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
|
||||
static_cast<int32_t*>(problem_sizes.data_ptr()), \
|
||||
static_cast<int64_t*>(a_strides.data_ptr()), \
|
||||
static_cast<int64_t*>(b_strides.data_ptr()), \
|
||||
static_cast<int64_t*>(c_strides.data_ptr()), a_stride_val, \
|
||||
b_stride_val, c_stride_val, K, N); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
@@ -144,6 +155,9 @@ void run_get_group_gemm_starts(
|
||||
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
|
||||
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
|
||||
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
|
||||
int64_t c_stride_val,
|
||||
/*these are used for their base addresses*/
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
|
||||
@@ -269,17 +283,16 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor c_strides1 =
|
||||
torch::full({num_experts}, output.stride(0), options_int);
|
||||
torch::Tensor a_strides1 =
|
||||
torch::full({num_experts}, a.stride(0) * 2, options_int);
|
||||
torch::Tensor b_strides1 =
|
||||
torch::full({num_experts}, b.stride(1) * 2, options_int);
|
||||
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
|
||||
|
||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
|
||||
expert_offsets, sf_offsets, problem_sizes, M, N, K);
|
||||
layout_sfa, layout_sfb, a_strides1, b_strides1, c_strides1,
|
||||
a.stride(0) * 2, b.stride(1) * 2, output.stride(0), a, b, output,
|
||||
a_blockscale, b_blockscales, alphas, expert_offsets, sf_offsets,
|
||||
problem_sizes, M, N, K);
|
||||
|
||||
// Create an instance of the GEMM
|
||||
Gemm gemm_op;
|
||||
@@ -444,17 +457,16 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor c_strides1 =
|
||||
torch::full({num_experts}, output.stride(0), options_int);
|
||||
torch::Tensor a_strides1 =
|
||||
torch::full({num_experts}, a.stride(0) * 2, options_int);
|
||||
torch::Tensor b_strides1 =
|
||||
torch::full({num_experts}, b.stride(1) * 2, options_int);
|
||||
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
|
||||
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
|
||||
|
||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
|
||||
expert_offsets, sf_offsets, problem_sizes, M, N, K);
|
||||
layout_sfa, layout_sfb, a_strides1, b_strides1, c_strides1,
|
||||
a.stride(0) * 2, b.stride(1) * 2, output.stride(0), a, b, output,
|
||||
a_blockscale, b_blockscales, alphas, expert_offsets, sf_offsets,
|
||||
problem_sizes, M, N, K);
|
||||
|
||||
// Create an instance of the GEMM
|
||||
Gemm gemm_op;
|
||||
|
||||
Reference in New Issue
Block a user