[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

Robert Shaw
2026-01-07 19:42:33 -05:00
committed by GitHub
parent ffc0a2798b
commit 5dcd7ef1f2
38 changed files with 1439 additions and 1528 deletions


@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
 class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(
         self,
+        e: int,
+        n: int,
+        k: int,
         out_dtype: torch.dtype | None,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         quant_config: FusedMoEQuantConfig,
+        device: torch.device,
     ):
         assert quant_config.use_fp8_w8a8
         super().__init__(quant_config)
+        # E: num_experts
+        # N: intermediate size per partition
+        # K: hidden dim
+        ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
+        ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
+        c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
         self.out_dtype = out_dtype
-        self.ab_strides1 = ab_strides1
+        self.ab_strides1 = ab_strides1_c_strides2
         self.ab_strides2 = ab_strides2
         self.c_strides1 = c_strides1
-        self.c_strides2 = c_strides2
+        self.c_strides2 = ab_strides1_c_strides2

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
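With this change the CUTLASS grouped-gemm stride tensors are derived once from the problem sizes (E experts, intermediate size N, hidden dim K) instead of being threaded through every constructor. A minimal sketch of the same computation with hypothetical sizes, following the shapes documented in the `cutlass_moe_fp8` docstring further down (gemm 1: [*, K] -> [*, 2N], gemm 2: [*, N] -> [*, K]):

```python
import torch

# Hypothetical sizes, for illustration only.
e, n, k = 8, 768, 2048
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One stride entry per expert. gemm 1 reads K-wide rows and writes 2N-wide rows;
# gemm 2 reads N-wide rows and writes K-wide rows, so its output stride equals
# the gemm-1 input stride and the two can share a single tensor.
ab_strides1 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
c_strides2 = ab_strides1

print(ab_strides1.tolist()[:3], c_strides1.tolist()[:3])  # [2048, 2048, 2048] [1536, 1536, 1536]
```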
@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
 class CutlassExpertsFp8(CutlassExpertsFp8Base):
-    def __init__(
-        self,
-        out_dtype: torch.dtype | None,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
-        quant_config: FusedMoEQuantConfig,
-    ):
-        super().__init__(
-            out_dtype,
-            ab_strides1,
-            ab_strides2,
-            c_strides1,
-            c_strides2,
-            quant_config,
-        )
     @property
     def activation_formats(
         self,
@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         self,
         max_experts_per_worker: int,
         num_dispatchers: int,
-        out_dtype: torch.dtype | None,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
-        quant_config: FusedMoEQuantConfig,
+        *args,
+        **kwargs,
     ):
-        super().__init__(
-            out_dtype,
-            ab_strides1,
-            ab_strides2,
-            c_strides1,
-            c_strides2,
-            quant_config,
-        )
+        super().__init__(*args, **kwargs)
         assert max_experts_per_worker > 0
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
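`CutlassBatchedExpertsFp8` now names only its own parameters and forwards everything else to `CutlassExpertsFp8Base` through `*args`/`**kwargs`, so later base-signature changes do not require edits to the subclass. A minimal sketch of the pattern with hypothetical stand-in classes (not the vLLM types):

```python
class _Base:
    def __init__(self, e: int, n: int, k: int, out_dtype=None):
        self.e, self.n, self.k, self.out_dtype = e, n, k, out_dtype


class _Batched(_Base):
    def __init__(self, max_experts_per_worker: int, num_dispatchers: int, *args, **kwargs):
        # Only the subclass-specific parameters are named; everything else
        # is passed through unchanged to the base constructor.
        super().__init__(*args, **kwargs)
        assert max_experts_per_worker > 0
        self.max_experts_per_worker = max_experts_per_worker
        self.num_dispatchers = num_dispatchers


experts = _Batched(4, 2, e=8, n=512, k=1024)  # hypothetical sizes
```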
@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         return (workspace1, workspace2, output)

-def cutlass_moe_fp8(
-    a: torch.Tensor,
-    w1_q: torch.Tensor,
-    w2_q: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
-    quant_config: FusedMoEQuantConfig,
-    activation: str = "silu",
-    expert_map: torch.Tensor | None = None,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-) -> torch.Tensor:
-    """
-    This function computes an a8w8-quantized Mixture of Experts (MoE) layer
-    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
-    mechanism. The matrix multiplications are implemented with CUTLASS
-    grouped gemm.
-
-    Parameters:
-    - a (torch.Tensor): The input tensor to the MoE layer.
-        Shape: [M, K]
-    - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
-        Shape: [num_experts, K, 2N] (the weights are passed transposed)
-    - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
-        Shape: [num_experts, N, K] (the weights are passed transposed)
-    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
-    - topk_ids (torch.Tensor): The token->expert mappings.
-    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
-        Shape: [num_experts] or [num_experts, 2N]
-    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
-        Shape: [num_experts] or [num_experts, K]
-    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
-        Shape: [num_experts]
-    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
-        Shape: [num_experts]
-    - c_strides1 (torch.Tensor): The output strides for the first gemm.
-        Shape: [num_experts]
-    - c_strides2 (torch.Tensor): The output strides for the second gemm.
-        Shape: [num_experts]
-    - per_act_token (Optional[bool]): Whether the scale is per-token or
-        per-tensor.
-    - activation (str): The activation function to use.
-    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
-        Shape: scalar or [M]
-    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
-        quantize the intermediate result between the gemms.
-        Shape: scalar or [M]
-    - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
-        every Rank is responsible for a subset of experts. expert_map is a
-        mapping from global expert-id to local expert-id. When expert_map[i]
-        is -1, it means that this Rank is not responsible for global
-        expert-id i.
-    - apply_router_weight_on_input (bool): When true, the topk weights are
-        applied directly on the inputs. This is only applicable when topk is 1.
-    - global_num_experts (int): The total number of experts.
-
-    Returns:
-    - torch.Tensor: The fp16 output tensor after applying the MoE layer.
-    """
-    assert quant_config is not None
-
-    if quant_config.a1_scale is not None:
-        assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1)
-    if quant_config.a2_scale is not None:
-        assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1)
-
-    if quant_config.w1_scale is not None:
-        if quant_config.per_out_ch_quant:
-            assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size(
-                1
-            ) == w1_q.size(1)
-        else:
-            assert (
-                quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1
-            )
-
-    num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
-
-    fn = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
-        CutlassExpertsFp8(
-            out_dtype=a.dtype,
-            ab_strides1=ab_strides1,
-            ab_strides2=ab_strides2,
-            c_strides1=c_strides1,
-            c_strides2=c_strides2,
-            quant_config=quant_config,
-        ),
-    )
-
-    return fn(
-        a,
-        w1_q,
-        w2_q,
-        topk_weights,
-        topk_ids,
-        activation=activation,
-        global_num_experts=num_experts,
-        expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-    )
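The assertions near the top of the removed `cutlass_moe_fp8` body tie the quantization mode to the scale shapes: per-act-token quantization means the activation scales carry one entry per token (`numel() != 1`), and per-output-channel weight quantization means `w1_scale` has a second dimension matching the weight's output dimension. A small illustrative sketch of those shape relationships, using plain tensors and hypothetical sizes rather than the vLLM `FusedMoEQuantConfig`:

```python
import torch

m, n, num_experts = 16, 512, 8  # hypothetical: tokens, intermediate size, experts

# Activation scales: per-tensor (scalar) vs. per-token (one entry per token).
a1_scale_per_tensor = torch.tensor(1.0)  # numel() == 1 -> per_act_token_quant must be False
a1_scale_per_token = torch.ones(m, 1)    # numel() != 1 -> per_act_token_quant must be True
for scale, per_act_token in [(a1_scale_per_tensor, False), (a1_scale_per_token, True)]:
    assert per_act_token == (scale.numel() != 1)

# Weight scales for gemm 1: one scale per expert vs. one per output channel (2N of them).
w1_scale_per_expert = torch.ones(num_experts, 1)       # size(1) == 1  -> per_out_ch_quant False
w1_scale_per_channel = torch.ones(num_experts, 2 * n)  # size(1) == 2N -> per_out_ch_quant True
assert w1_scale_per_channel.dim() > 1 and w1_scale_per_channel.size(1) == 2 * n
```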
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
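These module-level constants cache the largest representable values of the two quantized formats. Assuming the standard definitions of these dtypes, the e4m3fn maximum is 448.0 and the fp4 e2m1 maximum is 6.0:

```python
import torch

# Largest finite float8_e4m3fn value.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0

# e2m1 has 1 sign, 2 exponent, and 1 mantissa bit with no inf/nan encodings,
# so its largest magnitude is 1.5 * 2**2 = 6.0.
FLOAT4_E2M1_MAX = 6.0
```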