[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller(
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m,
|
||||
const int64_t n, const int64_t k);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
@@ -207,12 +216,13 @@ void cutlass_moe_mm(
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides);
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
@@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data(
|
||||
version_num, ". Required capability: 90");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||
problem_sizes2, expert_num_tokens,
|
||||
num_local_experts, padded_m, n, k);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
||||
"for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
|
||||
Reference in New Issue
Block a user