[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
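This change moves construction of FusedMoEQuantConfig out of the experts classes: the FusedMoEMethodBase subclasses now build the config once and pass it down, the experts classes assert its invariants and forward it to their base class, and weight/activation scales travel inside the config instead of through apply()'s argument list. A minimal sketch of the caller-side pattern, using names from the diff below; the import paths and all tensor values are illustrative assumptions, not part of this commit:

    import torch
    # Module paths assumed from the vLLM tree around this commit.
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8

    E, N, K = 8, 4096, 2048  # hypothetical expert count and gemm shapes

    # The quantization method builds the config once...
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        w1_scale=torch.ones(E, 1, 1, device="cuda"),  # placeholder scales
        w2_scale=torch.ones(E, 1, 1, device="cuda"),
    )
    # ...and the experts class receives it whole, instead of loose
    # per_act_token_quant / per_out_ch_quant / block_shape arguments.
    experts = CutlassExpertsFp8(
        out_dtype=torch.bfloat16,
        ab_strides1=torch.full((E,), K, dtype=torch.int64, device="cuda"),
        ab_strides2=torch.full((E,), N, dtype=torch.int64, device="cuda"),
        c_strides1=torch.full((E,), 2 * N, dtype=torch.int64, device="cuda"),
        c_strides2=torch.full((E,), K, dtype=torch.int64, device="cuda"),
        quant_config=quant_config,
    )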
@@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(
         self,
         out_dtype: Optional[torch.dtype],
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
         c_strides2: torch.Tensor,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
     ):
-        super().__init__(
-            FusedMoEQuantConfig(
-                quant_dtype=torch.float8_e4m3fn,
-                per_act_token_quant=per_act_token_quant,
-                per_out_ch_quant=per_out_ch_quant,
-                block_shape=block_shape,
-            ))
+        assert quant_config.use_fp8_w8a8
+        super().__init__(quant_config)
         self.out_dtype = out_dtype
         self.ab_strides1 = ab_strides1
         self.ab_strides2 = ab_strides2
@@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         activation: str,
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
-        assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
+        assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
+        assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

         expert_num_tokens = None
         if expert_tokens_meta is not None:
@@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
-            global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
-            self.c_strides2, workspace13, workspace2, expert_num_tokens,
+            global_num_experts, expert_map, self.w1_scale, self.w2_scale,
+            a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
+            self.c_strides1, self.c_strides2, workspace13, workspace2,
+            expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             use_batched_format, topk_weights)
@@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
     def __init__(
         self,
         out_dtype: Optional[torch.dtype],
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
         c_strides2: torch.Tensor,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
     ):
         super().__init__(
             out_dtype,
-            per_act_token_quant,
-            per_out_ch_quant,
             ab_strides1,
             ab_strides2,
             c_strides1,
             c_strides2,
-            block_shape,
+            quant_config,
         )

     @property
@@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         max_experts_per_worker: int,
         num_dispatchers: int,
         out_dtype: Optional[torch.dtype],
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
         c_strides2: torch.Tensor,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
     ):
         super().__init__(
             out_dtype,
-            per_act_token_quant,
-            per_out_ch_quant,
             ab_strides1,
             ab_strides2,
             c_strides1,
             c_strides2,
-            block_shape,
+            quant_config,
         )
         assert max_experts_per_worker > 0
         self.max_experts_per_worker = max_experts_per_worker
@@ -414,16 +395,12 @@ def cutlass_moe_fp8(
     w2_q: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     ab_strides1: torch.Tensor,
     ab_strides2: torch.Tensor,
     c_strides1: torch.Tensor,
     c_strides2: torch.Tensor,
-    per_act_token: Optional[bool] = None,
+    quant_config: FusedMoEQuantConfig,
     activation: str = "silu",
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
@@ -475,10 +452,18 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
-    if per_act_token is None:
-        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
-            a2_scale.numel() != 1 if a2_scale is not None else False)
-    per_out_ch = w1_scale.numel() != w1_q.size(0)
+    assert quant_config is not None
+
+    if quant_config.a1_scale is not None:
+        assert (quant_config.per_act_token_quant ==
+                (quant_config.a1_scale.numel() != 1))
+    if quant_config.a2_scale is not None:
+        assert (quant_config.per_act_token_quant ==
+                (quant_config.a2_scale.numel() != 1))
+
+    assert (quant_config.w1_scale is None
+            or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
+                                                  == w1_q.size(1))))

     num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
         0)
@@ -487,12 +472,11 @@ def cutlass_moe_fp8(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
             out_dtype=a.dtype,
-            per_act_token_quant=per_act_token,
-            per_out_ch_quant=per_out_ch,
             ab_strides1=ab_strides1,
             ab_strides2=ab_strides2,
             c_strides1=c_strides1,
             c_strides2=c_strides2,
+            quant_config=quant_config,
         ),
     )
@@ -502,14 +486,9 @@ def cutlass_moe_fp8(
         w2_q,
         topk_weights,
         topk_ids,
         False,
-        activation,
-        num_experts,
-        expert_map,
-        w1_scale,
-        w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
+        activation=activation,
+        global_num_experts=num_experts,
+        expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
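With the config threaded through, the fp8 entry point drops per_act_token, a1_scale/a2_scale, and w1_scale/w2_scale from its signature; they are read off quant_config instead. A hedged usage sketch of the new call shape (the leading a/w1_q arguments and all tensor contents are assumptions; the parameter order follows the cutlass_moe_fp8 signature hunk above):

    output = cutlass_moe_fp8(
        a,                       # [m, k] activations
        w1_q, w2_q,              # fp8-quantized expert weights
        topk_weights, topk_ids,  # router outputs
        ab_strides1, ab_strides2,
        c_strides1, c_strides2,
        quant_config,            # carries w1_scale/w2_scale/a1_scale/a2_scale
        activation="silu",
        expert_map=None,
        apply_router_weight_on_input=False,
        global_num_experts=-1,
    )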
@@ -542,7 +521,7 @@ def run_cutlass_moe_fp4(
 ) -> None:
     """
     MoE implementation for FP4 Inputs

     # Gemm 1
     a: Input tensor: [m, k] (half/bfloat16)
     a1_gscale: Activation scale per expert: [e] (float32)
@@ -552,16 +531,16 @@ def run_cutlass_moe_fp4(
                    full precision)
     w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
                    (Block size = 16 for NVFP4)

     # Gemm 2
     a2_gscale: Activation scale per expert: [e]
     w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
     w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
     w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3

     topk_weights: [m, topk] dtype: float8
     topk_ids: [m, topk] dtype: float8

     m, n, k: Unquantized weight shapes, dtype: int
     e: number of experts, dtype: int
@@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
         return


+# Split into batched and non-batched
 class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        g1_alphas: torch.Tensor,
-        g2_alphas: torch.Tensor,
-        a1_gscale: torch.Tensor,
-        a2_gscale: torch.Tensor,
         max_experts_per_worker: int,
         out_dtype: torch.dtype,
-        per_act_token_quant: bool,
-        per_out_ch_quant: bool,
-        block_shape: Optional[list[int]] = None,
+        quant_config: FusedMoEQuantConfig,
         use_batched_format: bool = False,
     ):
-        super().__init__(
-            # NVFP4 requires two levels of quantization, which involves
-            # computing some scaling factors dynamically. This makes it
-            # incompatible with the typical prepare -> MoE -> finalize
-            # pipeline. Move the quantization logic into the MoE body.
-            FusedMoEQuantConfig(
-                quant_dtype=None,  # skip quantization in prepare/finalize
-                per_act_token_quant=per_act_token_quant,
-                per_out_ch_quant=per_out_ch_quant,
-                block_shape=block_shape,
-            ))
+        super().__init__(quant_config)
         self.max_experts_per_worker = max_experts_per_worker
         self.out_dtype = out_dtype
         self.use_batched_format = use_batched_format
-
-        # TODO(bnell): put this stuff into quant config?
-        self.g1_alphas = g1_alphas
-        self.g2_alphas = g2_alphas
-        self.a1_gscale = a1_gscale
-        self.a2_gscale = a2_gscale

     @property
     def activation_formats(
         self
@@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         activation: str,
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: torch.Tensor,
+        a1q_scale: Optional[torch.Tensor],  # unused
         workspace13: Optional[torch.Tensor],
         workspace2: Optional[torch.Tensor],
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
             a=hidden_states,
             a1_gscale=self.a1_gscale,
             w1_fp4=w1,
-            w1_blockscale=w1_scale,
+            w1_blockscale=self.w1_scale,
             w1_alphas=self.g1_alphas,
             a2_gscale=self.a2_gscale,
             w2_fp4=w2,
-            w2_blockscale=w2_scale,
+            w2_blockscale=self.w2_scale,
             w2_alphas=self.g2_alphas,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
@@ -788,14 +741,9 @@ def cutlass_moe_fp4(
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w2_fp4: torch.Tensor,
-   w1_blockscale: torch.Tensor,
-   w2_blockscale: torch.Tensor,
-   g1_alphas: torch.Tensor,
-   g2_alphas: torch.Tensor,
-   a1_gscale: torch.Tensor,
-   a2_gscale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
+   quant_config: FusedMoEQuantConfig,
    m: int,
    n: int,
    k: int,
@@ -805,17 +753,31 @@ def cutlass_moe_fp4(
     assert expert_map is None, ("Expert Parallelism / expert_map "
                                 "is currently not supported for "
                                 "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")

+    # TODO(bnell): this feels a bit hacky
+    # NVFP4 requires two levels of quantization, which involves
+    # computing some scaling factors dynamically. This makes it
+    # incompatible with the typical prepare -> MoE -> finalize
+    # pipeline. Move the quantization logic into the MoE body.
+    quant_config = FusedMoEQuantConfig.make(
+        quant_dtype=None,  # skip quantization in prepare/finalize
+        per_act_token_quant=quant_config.per_act_token_quant,
+        per_out_ch_quant=quant_config.per_out_ch_quant,
+        block_shape=quant_config.block_shape,
+        g1_alphas=quant_config.g1_alphas,
+        g2_alphas=quant_config.g2_alphas,
+        a1_gscale=quant_config.a1_gscale,
+        a2_gscale=quant_config.a2_gscale,
+        w1_scale=quant_config.w1_scale,
+        w2_scale=quant_config.w2_scale,
+    )
+
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp4(
-            g1_alphas,
-            g2_alphas,
-            a1_gscale,
-            a2_gscale,
             max_experts_per_worker=e,
             out_dtype=a.dtype,
-            per_act_token_quant=False,
-            per_out_ch_quant=False,
+            quant_config=quant_config,
+            use_batched_format=False,
         ),
     )
@@ -830,10 +792,6 @@ def cutlass_moe_fp4(
         activation="silu",
         global_num_experts=e,
         expert_map=None,
-        w1_scale=w1_blockscale,
-        w2_scale=w2_blockscale,
-        a1_scale=None,
-        a2_scale=None,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
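The NVFP4 path applies the same delegation: the gemm alphas and global scales that CutlassExpertsFp4 previously stored as loose constructor arguments now ride in the config, and cutlass_moe_fp4 rebuilds that config with quant_dtype=None so the prepare/finalize stages skip quantization. A sketch of how a caller might pack an NVFP4 config before invoking cutlass_moe_fp4 (the tensor names are hypothetical; the FusedMoEQuantConfig.make call mirrors the hunk above):

    # Pack NVFP4 scales/alphas into the config; cutlass_moe_fp4 then passes
    # it through to CutlassExpertsFp4 instead of individual tensor arguments.
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=None,  # quantization happens inside the MoE body
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        w1_scale=w1_blockscale,
        w2_scale=w2_blockscale,
    )

Because NVFP4 needs two levels of quantization with dynamically computed scale factors, the quantization step cannot live in the usual prepare -> MoE -> finalize pipeline, which is why cutlass_moe_fp4 rebuilds the config rather than using it as passed.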
@@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
     return True


+# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
 def run_cutlass_block_scaled_fused_experts(
     a: torch.Tensor,
     w1: torch.Tensor,