[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
- const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
+ const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
+ const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
- const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
+ const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
+ const bool is_gated) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
- blockscale_offsets);
|
||||
+ blockscale_offsets, is_gated);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
|
||||
Reference in New Issue
Block a user