[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-12-08 22:29:06 -05:00
committed by GitHub
parent ea657f2078
commit f6227c22ab
22 changed files with 2045 additions and 101 deletions

View File

@@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
// Swap-AB should be disabled for FP4 path
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
bool may_swap_ab =
force_swap_ab.value_or((!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD));
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream,