Fix EPLB NVFP4 experts hook (#37217)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Elvir Crncevic <elvir@anthropic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Elvir Crnčević
2026-03-16 23:03:54 +01:00
committed by GitHub
parent c0f011918d
commit fd4d96302a
10 changed files with 57 additions and 25 deletions

View File

@@ -659,6 +659,13 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
"""CUTLASS FP4 fused MoE expert implementation."""
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Fuse the per-tensor activation scales into the scale-2 weight
    parameters, in place.

    The in-place update is deliberate: g1/g2_alphas in the quant config
    alias these same tensors, so mutating them here keeps the alphas in
    sync when EPLB rearranges the parameter.
    """
    pairs = (
        ("w13_weight_scale_2", "w13_input_scale"),
        ("w2_weight_scale_2", "w2_input_scale"),
    )
    for scale2_name, act_scale_name in pairs:
        scale2 = getattr(layer, scale2_name)
        scale2.data.mul_(getattr(layer, act_scale_name))
@property
def expects_unquantized_inputs(self) -> bool:
    """Whether this implementation takes unquantized activations.

    Always true here: quantization of the inputs happens inside the
    expert kernel path rather than upstream — TODO confirm against the
    kernel entry point.
    """
    return True

View File

@@ -56,10 +56,25 @@ class TrtLlmNvFp4ExpertsBase:
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
else:
self.g1_scale_c = (
torch.ones_like(self.quant_config.a1_gscale)
* self.quant_config.a2_gscale
)
self.g1_scale_c = self.quant_config.a2_gscale.clone()
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Fuse activation scales into the scale-2 weights and publish
    ``g1_scale_c`` as a layer parameter.

    The first two in-place multiplies fuse the activation scales into
    ``w13_weight_scale_2`` / ``w2_weight_scale_2``; ``g1_alphas`` in the
    quant config aliases the former, so it is updated at the same time.
    ``g1_scale_c`` is then recomputed from the fused values and
    registered as a parameter so that EPLB rearranges it alongside the
    other expert weights.
    """
    layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
    layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)

    qc = self.quant_config
    assert qc.g1_alphas is not None
    assert qc.a2_gscale is not None

    # g1_alphas was just fused in place above, so this uses the
    # activation-fused values.
    if self.moe_config.is_act_and_mul:
        fused = qc.g1_alphas * qc.a2_gscale
    else:
        fused = qc.a2_gscale.clone()

    param = torch.nn.Parameter(fused, requires_grad=False)
    layer.register_parameter("g1_scale_c", param)
    self.g1_scale_c = layer.g1_scale_c
@staticmethod
def _supports_current_device() -> bool:

View File

@@ -49,6 +49,10 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
)
self.out_dtype = moe_config.in_dtype
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Fold the per-tensor activation scales into the scale-2 weight
    parameters in place.

    Done in place so any quant-config tensors aliasing these parameters
    stay in sync — presumably for EPLB rearrangement, matching the other
    NVFP4 expert hooks; verify against the quant config construction.
    """
    for weight_scale_2, input_scale in (
        (layer.w13_weight_scale_2, layer.w13_input_scale),
        (layer.w2_weight_scale_2, layer.w2_input_scale),
    ):
        weight_scale_2.data.mul_(input_scale)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Report the activation layout this implementation consumes:
    batched-per-expert activations."""
    fmt = mk.FusedMoEActivationFormat.BatchedExperts
    return fmt

View File

@@ -61,6 +61,11 @@ def is_valid_flashinfer_cutlass_fused_moe(
class FlashInferExperts(mk.FusedMoEExpertsModular):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Fuse activation scales into the scale-2 weights in place, but
    only when running NVFP4 W4A4.

    For other quantization modes this hook is a no-op. The in-place
    multiply keeps quant-config tensors that alias these parameters in
    sync — presumably for EPLB rearrangement; confirm against the quant
    config construction.
    """
    # Guard clause: nothing to fuse unless NVFP4 W4A4 is active.
    if not self.quant_config.use_nvfp4_w4a4:
        return
    layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
    layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
def __init__(
self,
moe_config: mk.FusedMoEConfig,

View File

@@ -1421,19 +1421,23 @@ class FusedMoE(CustomOp):
weights = list(self.named_parameters())
weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
# `w13_input_scale` and `w2_input_scale` are global per-tensor
# activation scales shared across all experts (e.g. NVFP4).
# They are broadcast views (stride 0) from .expand() and are
# not actual expert weights, so exclude them from EPLB.
NON_EXPERT_WEIGHTS = {
"e_score_correction_bias",
"w13_input_scale",
"w2_input_scale",
}
assert all(
weight.is_contiguous()
for name, weight in weights
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
and name not in NON_EXPERT_WEIGHTS
)
# Filter out the non-expert weights.
# `e_score_correction_bias` is a bias for each logical expert,
# with shape (num_logical_experts,), not an expert weight.
NON_EXPERT_WEIGHTS = {
"e_score_correction_bias",
}
return [
weight.view(self.local_num_experts, -1)
for name, weight in weights

View File

@@ -489,6 +489,9 @@ class FusedMoEExperts(ABC):
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:  # noqa: B027
    """Optional hook run after the layer's weights are loaded.

    The base implementation is deliberately a no-op; subclasses may
    override it to post-process weights in place (the noqa suppresses
    the empty-method-in-base-class lint for this intentional default).
    """
    pass
@staticmethod
def is_monolithic() -> bool:
raise NotImplementedError("Implemented by subclasses.")

View File

@@ -374,11 +374,13 @@ def make_nvfp4_moe_quant_config(
w2_scale=w2_scale,
)
g1_alphas = a13_scale * w13_scale_2
g2_alphas = a2_scale * w2_scale_2
# Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
# The expert's process_weights_after_loading will fuse activation
# scales in-place. Since the quant config references the same tensor
# as the registered parameter, EPLB rearrangement stays in sync.
return nvfp4_moe_quant_config(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2,
a1_gscale=(1.0 / a13_scale),
a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale,

View File

@@ -570,6 +570,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
def maybe_make_prepare_finalize(
self,

View File

@@ -1394,6 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(

View File

@@ -267,16 +267,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
num_experts=w13.size(0),
is_gated_activation=is_gated,
)
# We do not need to make this a parameter, because
# it is not used during the weight (re)-loading process.
if is_gated:
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
else:
layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
layer.a1_gscale = 1.0 / a13_scale
layer.g1_alphas = a13_scale * w13_scale_2
layer.g2_alphas = a2_scale * w2_scale_2
else:
# Swizzle the block scales for other FI NVFP4 MoE kernels.
w13_scale = swizzle_blockscale(w13_scale)