[Perf] Fuse stride preparation for NVFP4 cutlass_moe (#31837)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-01-07 13:31:26 -05:00
committed by GitHub
parent 05f47bd8d2
commit f347ac6c34

View File

@@ -62,7 +62,9 @@ __global__ void __get_group_gemm_starts(
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
const int K, const int N) {
int64_t* a_strides, int64_t* b_strides, int64_t* c_strides,
const int64_t a_stride_val, const int64_t b_stride_val,
const int64_t c_stride_val, const int K, const int N) {
int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
@@ -103,6 +105,11 @@ __global__ void __get_group_gemm_starts(
// Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
// Write the per-expert A/B/C strides here. The value is the same for every
// expert, so filling them in this kernel avoids separate host-side fill kernels.
a_strides[expert_id] = a_stride_val;
b_strides[expert_id] = b_stride_val;
c_strides[expert_id] = c_stride_val;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
@@ -135,7 +142,11 @@ __global__ void __get_group_gemm_starts(
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
static_cast<int32_t*>(problem_sizes.data_ptr()), \
static_cast<int64_t*>(a_strides.data_ptr()), \
static_cast<int64_t*>(b_strides.data_ptr()), \
static_cast<int64_t*>(c_strides.data_ptr()), a_stride_val, \
b_stride_val, c_stride_val, K, N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
@@ -144,6 +155,9 @@ void run_get_group_gemm_starts(
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
@@ -269,17 +283,16 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 =
torch::full({num_experts}, output.stride(0), options_int);
torch::Tensor a_strides1 =
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
expert_offsets, sf_offsets, problem_sizes, M, N, K);
layout_sfa, layout_sfb, a_strides1, b_strides1, c_strides1,
a.stride(0) * 2, b.stride(1) * 2, output.stride(0), a, b, output,
a_blockscale, b_blockscales, alphas, expert_offsets, sf_offsets,
problem_sizes, M, N, K);
// Create an instance of the GEMM
Gemm gemm_op;
@@ -444,17 +457,16 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor c_strides1 =
torch::full({num_experts}, output.stride(0), options_int);
torch::Tensor a_strides1 =
torch::full({num_experts}, a.stride(0) * 2, options_int);
torch::Tensor b_strides1 =
torch::full({num_experts}, b.stride(1) * 2, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
expert_offsets, sf_offsets, problem_sizes, M, N, K);
layout_sfa, layout_sfb, a_strides1, b_strides1, c_strides1,
a.stride(0) * 2, b.stride(1) * 2, output.stride(0), a, b, output,
a_blockscale, b_blockscales, alphas, expert_offsets, sf_offsets,
problem_sizes, M, N, K);
// Create an instance of the GEMM
Gemm gemm_op;