[Kernel] Support W4A8 Grouped GEMM on Hopper (#29691)

Author: czhu-cohere <conway.zhu@cohere.com>
Committed by: GitHub, 2025-12-08 22:29:06 -05:00
Commit: f6227c22ab (parent: ea657f2078)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>

22 changed files with 2045 additions and 101 deletions


@@ -1052,3 +1052,404 @@ def run_cutlass_block_scaled_fused_experts(
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)
# W4A8
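# W4A8 (int4 weights, fp8 activations) grouped-GEMM MoE path. The stages below are:
#   1. permute the fp8 activations into expert-contiguous order (moe_permute)
#   2. first CUTLASS mixed-dtype grouped GEMM (cutlass_w4a8_moe_mm), 2N wide
#   3. activation (e.g. silu-and-mul), halving the width to N
#   4. dynamic per-token fp8 re-quantization of the activation output
#   5. second grouped GEMM back to hidden size K
#   6. unpermute + topk-weighted reduction (moe_unpermute)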
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation_callable: Callable,
global_num_experts: int,
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
w1_chan_scale: torch.Tensor,
w2_chan_scale: torch.Tensor,
a_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides1: torch.Tensor,
b_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: torch.Tensor | None,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
topk_weights: torch.Tensor | None,
group_size: int,
):
a1q = hidden_states
M = a1q.size(0)
local_E = w1.size(0)
device = a1q.device
_, K, N_packed = w2.shape
N = N_packed * 8  # logical N; 8 int4 values are packed into one int32
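# For example (hypothetical shapes): a w2 tensor of shape [E, 7168, 256] (int32)
# describes a logical [E, 7168, 2048] int4 weight, since each int32 word holds
# 8 int4 values.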
assert per_act_token, "W4A8 must use per-token scales"
assert per_out_ch, "W4A8 must use per-channel scales"
assert w1_scale is not None
assert w2_scale is not None
assert w1_scale.dtype == torch.float8_e4m3fn
assert w2_scale.dtype == torch.float8_e4m3fn
assert w1.dtype == torch.int32
assert w2.dtype == torch.int32
assert w1_chan_scale.dtype == torch.float32
assert w2_chan_scale.dtype == torch.float32
assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
assert a1q_scale is not None
assert a2_scale is None
assert out_dtype in [torch.bfloat16], f"Invalid output dtype: {out_dtype}"
if expert_map is not None:
assert expert_num_tokens is None
assert not use_batched_format, "batched format not supported yet"
assert group_size == 128, f"Only group size 128 supported but got {group_size=}"
assert global_num_experts != -1
assert w1.size(2) * 8 == K, (
f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
)
# Translate info from expert_map to topk_ids
if expert_map is not None:
local_topk_ids = torch.where(
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
)
else:
local_topk_ids = topk_ids
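# Illustrative example (hypothetical values): with expert_map = [-1, -1, 0, 1]
# (this rank owns global experts 2 and 3), topk_ids = [[2, 0], [3, 2]] becomes
# local_topk_ids = [[0, -1], [1, 0]]; tokens routed to -1 are not processed on
# this rank.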
topk = local_topk_ids.size(1)
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
# the original workspaces are sized based on the input hidden_states dtype (bf16)
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N)
)
mm2_out = _resize_cache(workspace2, (M * topk, K))
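# Buffer reuse between the two workspaces (sized in workspace_shapes below):
#   workspace2  -> a1q_perm (fp8 view), then act_out, then mm2_out
#   workspace13 -> mm1_out, then quant_out (fp8 view)
# so each stage's input and output never alias the same workspace.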
problem_sizes1 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
a1q,
a1q_scale,
topk_ids,
num_expert,
local_E,
expert_map,
permuted_hidden_states=a1q_perm,
)
expert_offsets = expert_offsets[:-1]
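# moe_permute returns cumulative per-expert token offsets with a trailing total
# entry, e.g. (hypothetical) per-expert counts [3, 0, 5] -> offsets [0, 3, 3, 8];
# the trailing total is dropped so there is one start offset per local expert.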
# For the RS GEMM, SwapAB is always enabled (the logical M and N are swapped in the problem shape)
ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids,
problem_sizes1,
problem_sizes2,
global_num_experts,
N,
K,
force_swap_ab=True,
)
ops.cutlass_w4a8_moe_mm(
mm1_out,
a1q,
w1,
a1q_scale,
w1_chan_scale,
w1_scale,
group_size,
expert_offsets,
problem_sizes1,
a_strides1,
b_strides1,
c_strides1,
s_strides1,
)
activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
)
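# a2_scale is None (asserted above), so scaled_fp8_quant computes dynamic
# per-token scales for the second GEMM's activations and writes the quantized
# values into quant_out (the workspace13 fp8 view).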
if expert_map is not None:
mm2_out.fill_(0)
ops.cutlass_w4a8_moe_mm(
mm2_out,
a2q,
w2,
a2q_scale,
w2_chan_scale,
w2_scale,
group_size,
expert_offsets,
problem_sizes2,
a_strides2,
b_strides2,
c_strides2,
s_strides2,
)
# In non-chunking mode the output is resized from workspace13,
# so mm2_out must use workspace2.
moe_unpermute(
out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm,
)
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: torch.dtype | None,
a_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides1: torch.Tensor,
b_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
group_size: int,
):
super().__init__(quant_config)
self.out_dtype = out_dtype
self.a_strides1 = a_strides1
self.a_strides2 = a_strides2
self.b_strides1 = b_strides1
self.b_strides2 = b_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.s_strides1 = s_strides1
self.s_strides2 = s_strides2
self.group_size = group_size
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weighting and reduction are fused into the moe_unpermute CUDA kernel
return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output)
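# For example (hypothetical sizes), with M=8, topk=2, N=4096, K=7168:
#   workspace1 = (16, 7168), workspace2 = (16, 7168), output = (8, 7168)
# and with N=16384, K=7168:
#   workspace1 = (16, 16384), workspace2 = (16, 8192), output = (8, 7168)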
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
)
assert not use_batched_format, "batched format not supported"
in_dtype = hidden_states.dtype
run_cutlass_moe_w4a8_fp8(
output,
hidden_states,
w1,
w2,
topk_ids,
activation_callable,
global_num_experts,
expert_map,
self.w1_scale,
self.w2_scale,
a1q_scale,
a2_scale,
self.g1_alphas, # per-channel scales
self.g2_alphas, # per-channel scales
self.a_strides1,
self.a_strides2,
self.b_strides1,
self.b_strides2,
self.c_strides1,
self.c_strides2,
self.s_strides1,
self.s_strides2,
workspace13,
workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant,
self.per_out_ch_quant,
use_batched_format,
topk_weights,
self.group_size,
)
def cutlass_moe_w4a8_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides1: torch.Tensor,
b_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
group_size: int = 128,
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and a top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
mixed-dtype grouped GEMM.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of int4-quantized expert weights,
packed into torch.int32 (8 int4 values per int32 word).
Shape: [num_experts, 2*N, K // packed_factor]
- w2_q (torch.Tensor): The second set of int4-quantized expert weights,
packed into torch.int32.
Shape: [num_experts, K, N // packed_factor]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- a_strides1 (torch.Tensor): The input strides for the first gemm.
Shape: [num_experts]
- a_strides2 (torch.Tensor): The input strides for the second gemm.
Shape: [num_experts]
- b_strides1 (torch.Tensor): The packed layout for the first gemm weights.
Shape: [num_experts, 3]
dtype: torch.int32
- b_strides2 (torch.Tensor): The packed layout for the second gemm weights.
Shape: [num_experts, 3]
dtype: torch.int32
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- s_strides1 (torch.Tensor): Strides for the group-wise scales for the first gemm.
Shape: [num_experts, 2]
dtype: torch.int64
- s_strides2 (torch.Tensor): Strides for the group-wise scales for the second gemm.
Shape: [num_experts, 2]
dtype: torch.int64
- quant_config (FusedMoEQuantConfig): Quantization configuration carrying
the weight/activation scales and the per-token / per-channel settings.
- activation (str): The activation function to use.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
- group_size (int): The number of weight elements that share one group-wise
scale factor.
Returns:
- torch.Tensor: The bf16 output tensor after applying the MoE layer.
"""
assert quant_config is not None
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsW4A8Fp8(
out_dtype=a.dtype,
a_strides1=a_strides1,
a_strides2=a_strides2,
b_strides1=b_strides1,
b_strides2=b_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
s_strides1=s_strides1,
s_strides2=s_strides2,
quant_config=quant_config,
group_size=group_size,
),
)
return fn(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
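# Hypothetical usage sketch (shapes only; the packed weights, the scales held by
# quant_config, and the stride tensors come from the W4A8 weight loader and are
# not constructed here):
#
#   out = cutlass_moe_w4a8_fp8(
#       a,                        # [M, K] bf16 activations
#       w1_q, w2_q,               # [E, 2N, K // 8] / [E, K, N // 8] int32
#       topk_weights, topk_ids,   # [M, topk]
#       a_strides1, a_strides2,
#       b_strides1, b_strides2,
#       c_strides1, c_strides2,
#       s_strides1, s_strides2,
#       quant_config=quant_config,
#       activation="silu",
#       expert_map=None,          # or a global->local map under expert parallel
#       global_num_experts=E,
#       group_size=128,
#   )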