[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
@@ -1052,3 +1052,404 @@ def run_cutlass_block_scaled_fused_experts(
    return (
        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
    ).sum(dim=1)


# W4A8
def run_cutlass_moe_w4a8_fp8(
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
    activation_callable: Callable,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    w1_chan_scale: torch.Tensor,
    w2_chan_scale: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_num_tokens: torch.Tensor | None,
    out_dtype: torch.dtype,
    per_act_token: bool,
    per_out_ch: bool,
    use_batched_format: bool,
    topk_weights: torch.Tensor | None,
    group_size: int,
):
    a1q = hidden_states
    M = a1q.size(0)
    local_E = w1.size(0)
    device = a1q.device
    _, K, N_packed = w2.shape
    N = N_packed * 8  # logical N, pack 8 int4 into 1 int32
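    # Illustrative example (the values are assumed, not from this kernel): a w2
    # of shape [E, K, N_packed] = [8, 7168, 256] stores eight 4-bit values per
    # int32, so the logical intermediate width is N = 256 * 8 = 2048.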

    assert per_act_token, "W4A8 must use per-token scales"
    assert per_out_ch, "W4A8 must use per-channel scales"
    assert w1_scale is not None
    assert w2_scale is not None
    assert w1_scale.dtype == torch.float8_e4m3fn
    assert w2_scale.dtype == torch.float8_e4m3fn
    assert w1.dtype == torch.int32
    assert w2.dtype == torch.int32
    assert w1_chan_scale.dtype == torch.float32
    assert w2_chan_scale.dtype == torch.float32
    assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
    assert a1q_scale is not None
    assert a2_scale is None
    assert out_dtype in [torch.bfloat16], f"Invalid output dtype: {out_dtype}"
    if expert_map is not None:
        assert expert_num_tokens is None
    assert not use_batched_format, "batched format not supported yet"
    assert group_size == 128, f"Only group size 128 supported but got {group_size=}"

    assert global_num_experts != -1
    assert w1.size(2) * 8 == K, (
        f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
    )

    # Translate info from expert_map to topk_ids
    if expert_map is not None:
        local_topk_ids = torch.where(
            expert_map[topk_ids] != -1, expert_map[topk_ids], -1
        )
    else:
        local_topk_ids = topk_ids
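    # Worked example (hypothetical ids): with expert_map = [-1, -1, 0, 1]
    # (this rank owns global experts 2 and 3), topk_ids [[2, 0], [3, 2]] maps
    # to local_topk_ids [[0, -1], [1, 0]]; -1 marks tokens routed elsewhere.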

    topk = local_topk_ids.size(1)
    a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
    mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
    act_out = _resize_cache(workspace2, (M * topk, N))
    # the original workspaces are sized for the input hidden_states dtype (bf16)
    quant_out = _resize_cache(
        workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N)
    )
    mm2_out = _resize_cache(workspace2, (M * topk, K))
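    # Buffer plan: the two GEMMs ping-pong between the workspaces, so no extra
    # allocation is needed. workspace2 backs a1q_perm (fp8 view), act_out and
    # mm2_out; workspace13 backs mm1_out and quant_out (fp8 view).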

    problem_sizes1 = torch.empty(
        (global_num_experts, 3), dtype=torch.int32, device=device
    )
    problem_sizes2 = torch.empty(
        (global_num_experts, 3), dtype=torch.int32, device=device
    )

    num_expert = global_num_experts if expert_map is None else expert_map.size(0)
    # permuted a1q reuses workspace2
    a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
        a1q,
        a1q_scale,
        topk_ids,
        num_expert,
        local_E,
        expert_map,
        permuted_hidden_states=a1q_perm,
    )
    expert_offsets = expert_offsets[:-1]

    # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
    ops.get_cutlass_moe_mm_problem_sizes(
        local_topk_ids,
        problem_sizes1,
        problem_sizes2,
        global_num_experts,
        N,
        K,
        force_swap_ab=True,
    )
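    # With force_swap_ab=True each expert's problem is recorded with logical M
    # and N exchanged, i.e. roughly (N, m_e, K) rather than (m_e, N, K) for the
    # m_e tokens routed to expert e, which tends to suit the typically small
    # per-expert token counts.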

    ops.cutlass_w4a8_moe_mm(
        mm1_out,
        a1q,
        w1,
        a1q_scale,
        w1_chan_scale,
        w1_scale,
        group_size,
        expert_offsets,
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides1,
    )
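    # The call above pairs each operand with its scales: fp8 activations a1q
    # with per-token a1q_scale, and packed int4 weights w1 with per-channel
    # w1_chan_scale plus group-wise (group_size) w1_scale.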

    activation_callable(act_out, mm1_out)

    a2q, a2q_scale = ops.scaled_fp8_quant(
        act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
    )
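    # The bf16 activation output is re-quantized to fp8 per token for the
    # second grouped GEMM; writing into quant_out reuses workspace13 in place.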

    if expert_map is not None:
        mm2_out.fill_(0)

    ops.cutlass_w4a8_moe_mm(
        mm2_out,
        a2q,
        w2,
        a2q_scale,
        w2_chan_scale,
        w2_scale,
        group_size,
        expert_offsets,
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
    )

    # In non-chunking mode the output is resized from workspace13,
    # so mm2_out must use workspace2.
    moe_unpermute(
        out=output,
        permuted_hidden_states=mm2_out,
        topk_weights=topk_weights,
        inv_permuted_idx=inv_perm,
    )


class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(
        self,
        out_dtype: torch.dtype | None,
        a_strides1: torch.Tensor,
        a_strides2: torch.Tensor,
        b_strides1: torch.Tensor,
        b_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        s_strides1: torch.Tensor,
        s_strides2: torch.Tensor,
        quant_config: FusedMoEQuantConfig,
        group_size: int,
    ):
        super().__init__(quant_config)
        self.out_dtype = out_dtype
        self.a_strides1 = a_strides1
        self.a_strides2 = a_strides2
        self.b_strides1 = b_strides1
        self.b_strides2 = b_strides2
        self.c_strides1 = c_strides1
        self.c_strides2 = c_strides2
        self.s_strides1 = s_strides1
        self.s_strides2 = s_strides2
        self.group_size = group_size

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (
            mk.FusedMoEActivationFormat.Standard,
            mk.FusedMoEActivationFormat.Standard,
        )

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # topk weights and reduction are fused in the moe_unpermute cuda kernel
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        return self.out_dtype if self.out_dtype is not None else act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        workspace1 = (M * topk, max(N, K))
        workspace2 = (M * topk, max(N // 2, K))
        output = (M, K)
        return (workspace1, workspace2, output)
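        # Sizing note (inferred from run_cutlass_moe_w4a8_fp8, not stated in
        # this commit): N here is the w1 output dimension, so workspace1 must
        # hold the (M * topk, N) mm1 output and workspace2 the (M * topk,
        # N // 2) activation output, each also big enough for (M * topk, K).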

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
        assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

        expert_num_tokens = None
        activation_callable = lambda o, i: self.activation(activation, o, i)

        use_batched_format = (
            self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
        )
        assert not use_batched_format, "batched format not supported"

        in_dtype = hidden_states.dtype

        run_cutlass_moe_w4a8_fp8(
            output,
            hidden_states,
            w1,
            w2,
            topk_ids,
            activation_callable,
            global_num_experts,
            expert_map,
            self.w1_scale,
            self.w2_scale,
            a1q_scale,
            a2_scale,
            self.g1_alphas,  # per-channel scales
            self.g2_alphas,  # per-channel scales
            self.a_strides1,
            self.a_strides2,
            self.b_strides1,
            self.b_strides2,
            self.c_strides1,
            self.c_strides2,
            self.s_strides1,
            self.s_strides2,
            workspace13,
            workspace2,
            expert_num_tokens,
            self.out_dtype if self.out_dtype is not None else in_dtype,
            self.per_act_token_quant,
            self.per_out_ch_quant,
            use_batched_format,
            topk_weights,
            self.group_size,
        )


def cutlass_moe_w4a8_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    activation: str = "silu",
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    group_size: int = 128,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    mixed-dtype grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, 2*N, K // packed_factor]
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // packed_factor]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mappings.
    - a_strides1 (torch.Tensor): The input strides for the first gemm.
        Shape: [num_experts]
    - a_strides2 (torch.Tensor): The input strides for the second gemm.
        Shape: [num_experts]
    - b_strides1 (torch.Tensor): The packed layout for the first gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - b_strides2 (torch.Tensor): The packed layout for the second gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - c_strides1 (torch.Tensor): The output strides for the first gemm.
        Shape: [num_experts]
    - c_strides2 (torch.Tensor): The output strides for the second gemm.
        Shape: [num_experts]
    - s_strides1 (torch.Tensor): Strides for the group-wise scales for the first gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - s_strides2 (torch.Tensor): Strides for the group-wise scales for the second gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - quant_config (FusedMoEQuantConfig): The MoE quantization configuration
      (weight and activation scales).
    - activation (str): The activation function to use.
    - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
      every Rank is responsible for a subset of experts. expert_map is a
      mapping from global expert-id to local expert-id. When expert_map[i]
      is -1, it means that this Rank is not responsible for global
      expert-id i.
    - apply_router_weight_on_input (bool): When true, the topk weights are
      applied directly on the inputs. This is only applicable when topk is 1.
    - global_num_experts (int): The total number of experts.
    - group_size (int): The number of weights per scale factor.

    Returns:
    - torch.Tensor: The bf16 output tensor after applying the MoE layer.
    """
    assert quant_config is not None

    num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)

    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsW4A8Fp8(
            out_dtype=a.dtype,
            a_strides1=a_strides1,
            a_strides2=a_strides2,
            b_strides1=b_strides1,
            b_strides2=b_strides2,
            c_strides1=c_strides1,
            c_strides2=c_strides2,
            s_strides1=s_strides1,
            s_strides2=s_strides2,
            quant_config=quant_config,
            group_size=group_size,
        ),
    )

    return fn(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        activation=activation,
        global_num_experts=num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
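

# A minimal usage sketch (illustrative only; the shapes, device and router
# outputs below are assumptions, not part of this commit). It shows how the
# packed weights line up with the shapes documented above:
#
#     E, M, K, N, topk = 8, 16, 7168, 2048, 2
#     a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
#     w1_q = torch.zeros(E, 2 * N, K // 8, dtype=torch.int32, device="cuda")
#     w2_q = torch.zeros(E, K, N // 8, dtype=torch.int32, device="cuda")
#     topk_weights, topk_ids = ...  # produced by the router
#     out = cutlass_moe_w4a8_fp8(
#         a, w1_q, w2_q, topk_weights, topk_ids,
#         a_strides1, a_strides2, b_strides1, b_strides2,
#         c_strides1, c_strides2, s_strides1, s_strides2,
#         quant_config=quant_config, group_size=128,
#     )  # -> [M, K] bf16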