[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-17 19:43:31 -04:00
Committed by: GitHub
Parent: e6585ddb45
Commit: 5963b98b46
68 changed files with 2698 additions and 2526 deletions
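In short, the experts classes and the `cutlass_moe_*` entry points no longer accept loose quantization parameters (`per_act_token_quant`, `per_out_ch_quant`, `block_shape`, weight/activation scales); the caller builds one `FusedMoEQuantConfig` and passes it through. A minimal before/after sketch, assuming a vLLM tree at this commit (import paths and placeholder tensors are illustrative, and the `FusedMoEQuantConfig.make(...)` kwargs are inferred from its use later in this diff):

    import torch

    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8

    e, k = 8, 2048  # placeholder expert count / hidden size
    strides = {  # placeholder stride tensors for the CUTLASS kernels
        name: torch.full((e, ), k, device="cuda", dtype=torch.int64)
        for name in ("ab_strides1", "ab_strides2", "c_strides1", "c_strides2")
    }

    # Before: quantization flags were passed piecemeal and the class built
    # its own FusedMoEQuantConfig internally:
    #   CutlassExpertsFp8(out_dtype=torch.bfloat16, per_act_token_quant=False,
    #                     per_out_ch_quant=False, block_shape=None, **strides)

    # After: the caller (typically a FusedMoEMethodBase subclass) constructs
    # the config once and hands it to the experts class.
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        w1_scale=torch.ones(e, 1, 1, device="cuda"),  # placeholder scales
        w2_scale=torch.ones(e, 1, 1, device="cuda"),
    )
    experts = CutlassExpertsFp8(
        out_dtype=torch.bfloat16,
        quant_config=quant_config,
        **strides,
    )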


@@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(
         self,
         out_dtype: Optional[torch.dtype],
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
         c_strides2: torch.Tensor,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
     ):
-        super().__init__(
-            FusedMoEQuantConfig(
-                quant_dtype=torch.float8_e4m3fn,
-                per_act_token_quant=per_act_token_quant,
-                per_out_ch_quant=per_out_ch_quant,
-                block_shape=block_shape,
-            ))
+        assert quant_config.use_fp8_w8a8
+        super().__init__(quant_config)
         self.out_dtype = out_dtype
         self.ab_strides1 = ab_strides1
         self.ab_strides2 = ab_strides2
@@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         activation: str,
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
-        assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
+        assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
+        assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
 
         expert_num_tokens = None
         if expert_tokens_meta is not None:
@@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
-            global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
-            self.c_strides2, workspace13, workspace2, expert_num_tokens,
+            global_num_experts, expert_map, self.w1_scale, self.w2_scale,
+            a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
+            self.c_strides1, self.c_strides2, workspace13, workspace2,
+            expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             use_batched_format, topk_weights)
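The scale arguments disappear from `apply()` because the base class now owns the `FusedMoEQuantConfig` and exposes its tensors on `self`. A hypothetical sketch of that delegation pattern, assuming simple property forwarding (the real `FusedMoEPermuteExpertsUnpermute` base class lives in vllm/model_executor/layers/fused_moe/modular_kernel.py and may differ):

    from typing import Optional

    import torch

    class ExpertsBaseSketch:
        """Illustrative stand-in for FusedMoEPermuteExpertsUnpermute."""

        def __init__(self, quant_config) -> None:
            self.quant_config = quant_config

        # Forward the config's tensors so subclasses can write self.w1_scale
        # instead of threading them through every apply() call.
        @property
        def w1_scale(self) -> Optional[torch.Tensor]:
            return self.quant_config.w1_scale

        @property
        def w2_scale(self) -> Optional[torch.Tensor]:
            return self.quant_config.w2_scale

        @property
        def a2_scale(self) -> Optional[torch.Tensor]:
            return self.quant_config.a2_scale

        @property
        def w1_zp(self) -> Optional[torch.Tensor]:
            return self.quant_config.w1_zp

Only `a1q_scale`, which is computed dynamically during prepare, still flows through `apply()` as an argument.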
@@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
     def __init__(
         self,
         out_dtype: Optional[torch.dtype],
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
         c_strides2: torch.Tensor,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
     ):
         super().__init__(
             out_dtype,
-            per_act_token_quant,
-            per_out_ch_quant,
             ab_strides1,
             ab_strides2,
             c_strides1,
             c_strides2,
-            block_shape,
+            quant_config,
         )
 
     @property
@@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         max_experts_per_worker: int,
         num_dispatchers: int,
         out_dtype: Optional[torch.dtype],
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
         c_strides2: torch.Tensor,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
     ):
         super().__init__(
             out_dtype,
-            per_act_token_quant,
-            per_out_ch_quant,
             ab_strides1,
             ab_strides2,
             c_strides1,
             c_strides2,
-            block_shape,
+            quant_config,
         )
         assert max_experts_per_worker > 0
         self.max_experts_per_worker = max_experts_per_worker
@@ -414,16 +395,12 @@ def cutlass_moe_fp8(
     w2_q: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     ab_strides1: torch.Tensor,
     ab_strides2: torch.Tensor,
     c_strides1: torch.Tensor,
     c_strides2: torch.Tensor,
-    per_act_token: Optional[bool] = None,
+    quant_config: FusedMoEQuantConfig,
     activation: str = "silu",
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
@@ -475,10 +452,18 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
-    if per_act_token is None:
-        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
-            a2_scale.numel() != 1 if a2_scale is not None else False)
-    per_out_ch = w1_scale.numel() != w1_q.size(0)
+    assert quant_config is not None
+
+    if quant_config.a1_scale is not None:
+        assert (quant_config.per_act_token_quant ==
+                (quant_config.a1_scale.numel() != 1))
+    if quant_config.a2_scale is not None:
+        assert (quant_config.per_act_token_quant ==
+                (quant_config.a2_scale.numel() != 1))
+
+    assert (quant_config.w1_scale is None
+            or (quant_config.per_out_ch_quant ==
+                (quant_config.w1_scale.size(1) == w1_q.size(1))))
 
     num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
         0)
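The parentheses around the `numel() != 1` comparisons are load-bearing: Python chains comparison operators, so `flag == x.numel() != 1` parses as `(flag == x.numel()) and (x.numel() != 1)`, not as a comparison of the flag against a boolean. A standalone illustration:

    per_act_token_quant = True
    numel = 128  # e.g. one activation scale per token

    # Chained form: (True == 128) and (128 != 1) -> False, the wrong answer.
    print(per_act_token_quant == numel != 1)    # False

    # Parenthesized form: True == (128 != 1) -> True, the intended check.
    print(per_act_token_quant == (numel != 1))  # True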
@@ -487,12 +472,11 @@ def cutlass_moe_fp8(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
             out_dtype=a.dtype,
-            per_act_token_quant=per_act_token,
-            per_out_ch_quant=per_out_ch,
             ab_strides1=ab_strides1,
             ab_strides2=ab_strides2,
             c_strides1=c_strides1,
             c_strides2=c_strides2,
+            quant_config=quant_config,
         ),
     )
 
@@ -502,14 +486,9 @@ def cutlass_moe_fp8(
         w2_q,
         topk_weights,
         topk_ids,
-        False,
-        activation,
-        num_experts,
-        expert_map,
-        w1_scale,
-        w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
+        activation=activation,
+        global_num_experts=num_experts,
+        expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
 
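Put together, a call site moves from positional scales and flags to the config plus keyword arguments. A sketch of the new calling convention, reusing the placeholder names from the earlier fp8 sketch (the leading `a`, `w1_q`, `w2_q` parameters sit outside this hunk and are assumed):

    # Sketch only: `a`, `w1_q`, `w2_q`, the stride tensors and `quant_config`
    # are assumed to be prepared as in the earlier fp8 example.
    out = cutlass_moe_fp8(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        ab_strides1,
        ab_strides2,
        c_strides1,
        c_strides2,
        quant_config=quant_config,
        activation="silu",
        expert_map=None,
        apply_router_weight_on_input=False,
        global_num_experts=-1,
    )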
@@ -542,7 +521,7 @@ def run_cutlass_moe_fp4(
 ) -> None:
     """
     MoE implementation for FP4 Inputs
 
     # Gemm 1
     a: Input tensor: [m, k] (half/bfloat16)
     a1_gscale: Activation scale per expert: [e] (float32)
@@ -552,16 +531,16 @@ def run_cutlass_moe_fp4(
      full precision)
     w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
                    (Block size = 16 for NVFP4)
 
     # Gemm 2
     a2_gscale: Activation scale per expert: [e]
     w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
     w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
     w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
 
     topk_weights: [m, topk], dtype: float32 (routing weights)
     topk_ids: [m, topk], dtype: int (indices of the selected experts)
     m, n, k: Unquantized weight shapes, dtype: int
     e: number of experts, dtype: int
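For concreteness, the NVFP4 layout above can be sanity-checked with plain arithmetic (dimensions are illustrative, not from the commit):

    e, n, k = 8, 256, 512
    block_size = 16  # NVFP4 block size per the docstring

    # Two E2M1 values pack into each uint8 byte, hence the // 2 on the
    # innermost dimension; one fp8 blockscale covers block_size elements.
    w1_fp4_shape = (e, 2 * n, k // 2)
    w1_blockscale_shape = (e, 2 * n, k // block_size)
    w2_fp4_shape = (e, k, n // 2)
    w2_blockscale_shape = (e, k, n // block_size)
    print(w1_fp4_shape, w1_blockscale_shape)  # (8, 512, 256) (8, 512, 32)
    print(w2_fp4_shape, w2_blockscale_shape)  # (8, 512, 128) (8, 512, 16)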
@@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
     return
 
 
 # Split into batched and non-batched
 class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
-        g1_alphas: torch.Tensor,
-        g2_alphas: torch.Tensor,
-        a1_gscale: torch.Tensor,
-        a2_gscale: torch.Tensor,
         max_experts_per_worker: int,
         out_dtype: torch.dtype,
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
         use_batched_format: bool = False,
     ):
-        super().__init__(
-            # NVFP4 requires two levels of quantization, which involves
-            # computing some scaling factors dynamically. This makes it
-            # incompatible with the typical prepare -> MoE -> finalize
-            # pipeline. Move the quantization logic into the MoE body.
-            FusedMoEQuantConfig(
-                quant_dtype=None,  # skip quantization in prepare/finalize
-                per_act_token_quant=per_act_token_quant,
-                per_out_ch_quant=per_out_ch_quant,
-                block_shape=block_shape,
-            ))
+        super().__init__(quant_config)
         self.max_experts_per_worker = max_experts_per_worker
         self.out_dtype = out_dtype
         self.use_batched_format = use_batched_format
-        # TODO(bnell): put this stuff into quant config?
-        self.g1_alphas = g1_alphas
-        self.g2_alphas = g2_alphas
-        self.a1_gscale = a1_gscale
-        self.a2_gscale = a2_gscale
 
     @property
     def activation_formats(
         self
@@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         activation: str,
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: torch.Tensor,
+        a1q_scale: Optional[torch.Tensor],  # unused
         workspace13: Optional[torch.Tensor],
         workspace2: Optional[torch.Tensor],
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
             a=hidden_states,
             a1_gscale=self.a1_gscale,
             w1_fp4=w1,
-            w1_blockscale=w1_scale,
+            w1_blockscale=self.w1_scale,
             w1_alphas=self.g1_alphas,
             a2_gscale=self.a2_gscale,
             w2_fp4=w2,
-            w2_blockscale=w2_scale,
+            w2_blockscale=self.w2_scale,
             w2_alphas=self.g2_alphas,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
@@ -788,14 +741,9 @@ def cutlass_moe_fp4(
     a: torch.Tensor,
     w1_fp4: torch.Tensor,
    w2_fp4: torch.Tensor,
-    w1_blockscale: torch.Tensor,
-    w2_blockscale: torch.Tensor,
-    g1_alphas: torch.Tensor,
-    g2_alphas: torch.Tensor,
-    a1_gscale: torch.Tensor,
-    a2_gscale: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    quant_config: FusedMoEQuantConfig,
     m: int,
     n: int,
     k: int,
@@ -805,17 +753,31 @@ def cutlass_moe_fp4(
     assert expert_map is None, ("Expert Parallelism / expert_map "
                                 "is currently not supported for "
                                 "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
+
+    # TODO(bnell): this feels a bit hacky
+    # NVFP4 requires two levels of quantization, which involves
+    # computing some scaling factors dynamically. This makes it
+    # incompatible with the typical prepare -> MoE -> finalize
+    # pipeline. Move the quantization logic into the MoE body.
+    quant_config = FusedMoEQuantConfig.make(
+        quant_dtype=None,  # skip quantization in prepare/finalize
+        per_act_token_quant=quant_config.per_act_token_quant,
+        per_out_ch_quant=quant_config.per_out_ch_quant,
+        block_shape=quant_config.block_shape,
+        g1_alphas=quant_config.g1_alphas,
+        g2_alphas=quant_config.g2_alphas,
+        a1_gscale=quant_config.a1_gscale,
+        a2_gscale=quant_config.a2_gscale,
+        w1_scale=quant_config.w1_scale,
+        w2_scale=quant_config.w2_scale,
+    )
+
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp4(
-            g1_alphas,
-            g2_alphas,
-            a1_gscale,
-            a2_gscale,
             max_experts_per_worker=e,
             out_dtype=a.dtype,
-            per_act_token_quant=False,
-            per_out_ch_quant=False,
+            quant_config=quant_config,
             use_batched_format=False,
         ),
     )
@@ -830,10 +792,6 @@ def cutlass_moe_fp4(
         activation="silu",
         global_num_experts=e,
         expert_map=None,
-        w1_scale=w1_blockscale,
-        w2_scale=w2_blockscale,
-        a1_scale=None,
-        a2_scale=None,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
 
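On the FP4 side the blockscales, alphas, and global scales likewise move into the config, and `cutlass_moe_fp4` rebuilds the internal `quant_dtype=None` variant shown above. A hedged sketch of the new call (the caller-side `quant_dtype` and any trailing parameters beyond `e` are outside this hunk, so they are omitted or assumed):

    # Sketch only: the NVFP4 tensors (w1_blockscale, g1_alphas, ...) are
    # assumed to exist; the kwargs mirror the FusedMoEQuantConfig.make call
    # shown in the diff above.
    quant_config = FusedMoEQuantConfig.make(
        w1_scale=w1_blockscale,
        w2_scale=w2_blockscale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
    )

    out = cutlass_moe_fp4(
        a,
        w1_fp4,
        w2_fp4,
        topk_weights,
        topk_ids,
        quant_config,
        m=m,
        n=n,
        k=k,
        e=e,
        apply_router_weight_on_input=False,
    )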
@@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
     return True
 
 
+# TODO(bnell): would be nice to combine/integrate with regular cutlass_fp8.
 def run_cutlass_block_scaled_fused_experts(
     a: torch.Tensor,
     w1: torch.Tensor,