[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length, const int n,
|
||||
const int k) {
|
||||
const int k, const bool is_gated) {
|
||||
int expert_id = blockIdx.x;
|
||||
// For gated activations (gate + up), first GEMM output is 2*n.
|
||||
// For non-gated activations (up only), first GEMM output is n.
|
||||
int const n1 = is_gated ? 2 * n : n;
|
||||
|
||||
int occurrences = 0;
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
if constexpr (!SWAP_AB) {
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 1] = n1;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
} else {
|
||||
problem_sizes1[expert_id * 3] = 2 * n;
|
||||
problem_sizes1[expert_id * 3] = n1;
|
||||
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = k;
|
||||
@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab) {
|
||||
inline void launch_compute_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab, const bool is_gated) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
static_cast<int>(k), is_gated);
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
|
||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||
atomic_buffer, num_experts, n, k, stream,
|
||||
may_swap_ab);
|
||||
may_swap_ab, is_gated);
|
||||
|
||||
if (blockscale_offsets.has_value()) {
|
||||
// fp4 path
|
||||
|
||||
Reference in New Issue
Block a user