Fix eplb nvfp4 experts hook (#37217)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com> Signed-off-by: Elvir Crncevic <elvir@anthropic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -659,6 +659,13 @@ def run_cutlass_moe_fp4(
|
||||
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
"""CUTLASS FP4 fused MoE expert implementation."""
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fuse activation scales into w_scale_2 in-place so that
|
||||
# g1/g2_alphas (which reference the same tensor) stay in sync
|
||||
# when EPLB rearranges the parameter.
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -56,10 +56,25 @@ class TrtLlmNvFp4ExpertsBase:
|
||||
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
self.g1_scale_c = (
|
||||
torch.ones_like(self.quant_config.a1_gscale)
|
||||
* self.quant_config.a2_gscale
|
||||
)
|
||||
self.g1_scale_c = self.quant_config.a2_gscale.clone()
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
# Recompute g1_scale_c since g1_alphas was just fused in-place.
|
||||
# Register as a layer parameter so EPLB rearranges it alongside
|
||||
# other expert weights.
|
||||
assert self.quant_config.g1_alphas is not None
|
||||
assert self.quant_config.a2_gscale is not None
|
||||
if self.moe_config.is_act_and_mul:
|
||||
g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
g1_scale_c = self.quant_config.a2_gscale.clone()
|
||||
layer.register_parameter(
|
||||
"g1_scale_c",
|
||||
torch.nn.Parameter(g1_scale_c, requires_grad=False),
|
||||
)
|
||||
self.g1_scale_c = layer.g1_scale_c
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
|
||||
@@ -49,6 +49,10 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
)
|
||||
self.out_dtype = moe_config.in_dtype
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
@@ -61,6 +61,11 @@ def is_valid_flashinfer_cutlass_fused_moe(
|
||||
|
||||
|
||||
class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.quant_config.use_nvfp4_w4a4:
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: mk.FusedMoEConfig,
|
||||
|
||||
@@ -1421,19 +1421,23 @@ class FusedMoE(CustomOp):
|
||||
weights = list(self.named_parameters())
|
||||
weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
|
||||
|
||||
# `w13_input_scale` and `w2_input_scale` are global per-tensor
|
||||
# activation scales shared across all experts (e.g. NVFP4).
|
||||
# They are broadcast views (stride 0) from .expand() and are
|
||||
# not actual expert weights, so exclude them from EPLB.
|
||||
NON_EXPERT_WEIGHTS = {
|
||||
"e_score_correction_bias",
|
||||
"w13_input_scale",
|
||||
"w2_input_scale",
|
||||
}
|
||||
|
||||
assert all(
|
||||
weight.is_contiguous()
|
||||
for name, weight in weights
|
||||
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
|
||||
and name not in NON_EXPERT_WEIGHTS
|
||||
)
|
||||
|
||||
# Filter out the non-expert weights.
|
||||
# `e_score_correction_bias` is a bias for each logical expert,
|
||||
# with shape (num_logical_experts,), not an expert weight.
|
||||
NON_EXPERT_WEIGHTS = {
|
||||
"e_score_correction_bias",
|
||||
}
|
||||
|
||||
return [
|
||||
weight.view(self.local_num_experts, -1)
|
||||
for name, weight in weights
|
||||
|
||||
@@ -489,6 +489,9 @@ class FusedMoEExperts(ABC):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # noqa: B027
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
raise NotImplementedError("Implemented by subclasses.")
|
||||
|
||||
@@ -374,11 +374,13 @@ def make_nvfp4_moe_quant_config(
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
g1_alphas = a13_scale * w13_scale_2
|
||||
g2_alphas = a2_scale * w2_scale_2
|
||||
# Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
|
||||
# The expert's process_weights_after_loading will fuse activation
|
||||
# scales in-place. Since the quant config references the same tensor
|
||||
# as the registered parameter, EPLB rearrangement stays in sync.
|
||||
return nvfp4_moe_quant_config(
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
g1_alphas=w13_scale_2,
|
||||
g2_alphas=w2_scale_2,
|
||||
a1_gscale=(1.0 / a13_scale),
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
w1_scale=w13_scale,
|
||||
|
||||
@@ -570,6 +570,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
|
||||
@@ -1394,6 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
return make_nvfp4_moe_quant_config(
|
||||
|
||||
@@ -267,16 +267,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
num_experts=w13.size(0),
|
||||
is_gated_activation=is_gated,
|
||||
)
|
||||
|
||||
# We do not need to make this a parameter, because
|
||||
# it is not used during the weight (re)-loading process.
|
||||
if is_gated:
|
||||
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||
else:
|
||||
layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
|
||||
layer.a1_gscale = 1.0 / a13_scale
|
||||
layer.g1_alphas = a13_scale * w13_scale_2
|
||||
layer.g2_alphas = a2_scale * w2_scale_2
|
||||
else:
|
||||
# Swizzle the block scales for other FI NVFP4 MoE kernels.
|
||||
w13_scale = swizzle_blockscale(w13_scale)
|
||||
|
||||
Reference in New Issue
Block a user