[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
This commit is contained in:
@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
device: torch.dtype,
|
||||
):
|
||||
assert quant_config.use_fp8_w8a8
|
||||
super().__init__(quant_config)
|
||||
|
||||
# E: num_experts
|
||||
# N: intermediate size per partition
|
||||
# K: hidden dim
|
||||
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
|
||||
|
||||
self.out_dtype = out_dtype
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.ab_strides1 = ab_strides1_c_strides2
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides1 = c_strides1
|
||||
self.c_strides2 = c_strides2
|
||||
self.c_strides2 = ab_strides1_c_strides2
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
@@ -329,24 +337,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
|
||||
class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
@@ -390,21 +380,10 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
num_dispatchers: int,
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
quant_config,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
assert max_experts_per_worker > 0
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.num_dispatchers = num_dispatchers
|
||||
@@ -445,113 +424,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
activation: str = "silu",
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
||||
mechanism. The matrix multiplications are implemented with CUTLASS
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- a (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
||||
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- topk_ids (torch.Tensor): The token->expert mappings.
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides1 (torch.Tensor): The output strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides2 (torch.Tensor): The output strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- per_act_token (Optional[bool]): Whether the scale is per-token or
|
||||
per-tensor.
|
||||
- activation (str): The activation function to use.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||
every Rank is responsible for a subset of experts. expert_map is a
|
||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||
is -1, it means that this Rank is not responsible for global
|
||||
expert-id i.
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is 1.
|
||||
- global_num_experts (int): The total number of experts.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
assert quant_config is not None
|
||||
|
||||
if quant_config.a1_scale is not None:
|
||||
assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1)
|
||||
if quant_config.a2_scale is not None:
|
||||
assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1)
|
||||
|
||||
if quant_config.w1_scale is not None:
|
||||
if quant_config.per_out_ch_quant:
|
||||
assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size(
|
||||
1
|
||||
) == w1_q.size(1)
|
||||
else:
|
||||
assert (
|
||||
quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1
|
||||
)
|
||||
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
out_dtype=a.dtype,
|
||||
ab_strides1=ab_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides1=c_strides1,
|
||||
c_strides2=c_strides2,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fn(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
Reference in New Issue
Block a user