[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
|
||||
@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length, const int n,
|
||||
const int k) {
|
||||
const int k, const bool is_gated) {
|
||||
int expert_id = blockIdx.x;
|
||||
// For gated activations (gate + up), first GEMM output is 2*n.
|
||||
// For non-gated activations (up only), first GEMM output is n.
|
||||
int const n1 = is_gated ? 2 * n : n;
|
||||
|
||||
int occurrences = 0;
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
if constexpr (!SWAP_AB) {
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 1] = n1;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
} else {
|
||||
problem_sizes1[expert_id * 3] = 2 * n;
|
||||
problem_sizes1[expert_id * 3] = n1;
|
||||
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = k;
|
||||
@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab) {
|
||||
inline void launch_compute_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab, const bool is_gated) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
static_cast<int>(k), is_gated);
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
|
||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||
atomic_buffer, num_experts, n, k, stream,
|
||||
may_swap_ab);
|
||||
may_swap_ab, is_gated);
|
||||
|
||||
if (blockscale_offsets.has_value()) {
|
||||
// fp4 path
|
||||
|
||||
@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
blockscale_offsets);
|
||||
blockscale_offsets, is_gated);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
|
||||
@@ -489,8 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k, Tensor? blockscale_offsets) -> "
|
||||
"()");
|
||||
" int n, int k, Tensor? blockscale_offsets, "
|
||||
" bool is_gated) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
|
||||
accuracy_threshold: 0.29
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
|
||||
@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml
|
||||
Mixtral-8x7B-BF16-triton.yaml
|
||||
Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
|
||||
Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
|
||||
Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
|
||||
|
||||
@@ -989,6 +989,7 @@ def get_cutlass_moe_mm_data(
|
||||
n: int,
|
||||
k: int,
|
||||
blockscale_offsets: torch.Tensor | None = None,
|
||||
is_gated: bool = True,
|
||||
):
|
||||
"""
|
||||
Prepare data necessary to perform CUTLASS grouped matrix multiplications
|
||||
@@ -1012,6 +1013,8 @@ def get_cutlass_moe_mm_data(
|
||||
its computation. The number of block scale rows
|
||||
computed with expert E is blockscale_offsets[E + 1] -
|
||||
blockscale_offsets[E]
|
||||
- is_gated: Whether the activation is gated (gate + up). When True, the
|
||||
first GEMM N dimension is 2*n; when False, it is n.
|
||||
"""
|
||||
return torch.ops._C.get_cutlass_moe_mm_data(
|
||||
topk_ids,
|
||||
@@ -1024,6 +1027,7 @@ def get_cutlass_moe_mm_data(
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets,
|
||||
is_gated,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
|
||||
# Gemm 1
|
||||
a: Input tensor: [m, k] (half/bfloat16)
|
||||
a1_gscale: Activation scale per expert: [e] (float32)
|
||||
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
|
||||
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
|
||||
w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
|
||||
w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
|
||||
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
|
||||
(Note: `n` is the up projection output dim, `k` is the input dim in
|
||||
full precision)
|
||||
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
|
||||
w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
|
||||
(Block size = 16 for NVFP4)
|
||||
|
||||
# Gemm 2
|
||||
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
|
||||
|
||||
assumes that topk < k < n to satisfy - up/down projection expectations.
|
||||
"""
|
||||
is_gated = activation.is_gated
|
||||
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
|
||||
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
|
||||
w1_n = n * 2 if is_gated else n
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
|
||||
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
|
||||
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
|
||||
and w2_blockscale.ndim == 3
|
||||
), "All Weights must be of rank 3 for cutlass_moe_fp4"
|
||||
m_a, k_a = a.shape
|
||||
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
|
||||
e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
|
||||
e_w2, k_w2, half_n_w2 = w2_fp4.shape
|
||||
|
||||
assert e_w1 == e_w2 and e_w1 == e, (
|
||||
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
|
||||
assert k_a == half_k_w1 * 2 and k == k_w2, (
|
||||
"Hidden size mismatch between a, w1 and w2"
|
||||
)
|
||||
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
|
||||
assert w1_n_actual == w1_n and half_n_w2 * 2 == n, "mismatch in expected `n`"
|
||||
assert m == m_a, "input shape mismatch"
|
||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets,
|
||||
is_gated=is_gated,
|
||||
)
|
||||
|
||||
a = ops.shuffle_rows(a, a_map)
|
||||
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
)
|
||||
c1 = _resize_cache(workspace13, (m * topk, n * 2))
|
||||
c1 = _resize_cache(workspace13, (m * topk, w1_n))
|
||||
c2 = _resize_cache(workspace2, (m * topk, n))
|
||||
c3 = _resize_cache(workspace13, (m * topk, k))
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
# SILU uses a fused silu+mul+fp4_quant kernel path.
|
||||
# Other gated activations use the generic apply_moe_activation()
|
||||
# fallback + separate fp4 quantization in run_cutlass_moe_fp4().
|
||||
# Non-gated activations (_NO_MUL) are also supported for models
|
||||
# like Nemotron-Nano that don't use gated MLP.
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SWIGLUSTEP,
|
||||
MoEActivation.SILU_NO_MUL,
|
||||
MoEActivation.GELU_NO_MUL,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user