Fix EPLB NVFP4 experts hook (#37217)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Elvir Crncevic <elvir@anthropic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
(cherry picked from commit fd4d96302a)
This commit is contained in:
@@ -659,6 +659,13 @@ def run_cutlass_moe_fp4(
|
|||||||
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||||
"""CUTLASS FP4 fused MoE expert implementation."""
|
"""CUTLASS FP4 fused MoE expert implementation."""
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
# Fuse activation scales into w_scale_2 in-place so that
|
||||||
|
# g1/g2_alphas (which reference the same tensor) stay in sync
|
||||||
|
# when EPLB rearranges the parameter.
|
||||||
|
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||||
|
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expects_unquantized_inputs(self) -> bool:
|
def expects_unquantized_inputs(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -56,10 +56,25 @@ class TrtLlmNvFp4ExpertsBase:
|
|||||||
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||||
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||||
else:
|
else:
|
||||||
self.g1_scale_c = (
|
self.g1_scale_c = self.quant_config.a2_gscale.clone()
|
||||||
torch.ones_like(self.quant_config.a1_gscale)
|
|
||||||
* self.quant_config.a2_gscale
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
)
|
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||||
|
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||||
|
# Recompute g1_scale_c since g1_alphas was just fused in-place.
|
||||||
|
# Register as a layer parameter so EPLB rearranges it alongside
|
||||||
|
# other expert weights.
|
||||||
|
assert self.quant_config.g1_alphas is not None
|
||||||
|
assert self.quant_config.a2_gscale is not None
|
||||||
|
if self.moe_config.is_act_and_mul:
|
||||||
|
g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||||
|
else:
|
||||||
|
g1_scale_c = self.quant_config.a2_gscale.clone()
|
||||||
|
layer.register_parameter(
|
||||||
|
"g1_scale_c",
|
||||||
|
torch.nn.Parameter(g1_scale_c, requires_grad=False),
|
||||||
|
)
|
||||||
|
self.g1_scale_c = layer.g1_scale_c
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_current_device() -> bool:
|
def _supports_current_device() -> bool:
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
|||||||
)
|
)
|
||||||
self.out_dtype = moe_config.in_dtype
|
self.out_dtype = moe_config.in_dtype
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||||
|
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||||
|
|||||||
@@ -61,6 +61,11 @@ def is_valid_flashinfer_cutlass_fused_moe(
|
|||||||
|
|
||||||
|
|
||||||
class FlashInferExperts(mk.FusedMoEExpertsModular):
|
class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if self.quant_config.use_nvfp4_w4a4:
|
||||||
|
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||||
|
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
moe_config: mk.FusedMoEConfig,
|
moe_config: mk.FusedMoEConfig,
|
||||||
|
|||||||
@@ -1421,19 +1421,23 @@ class FusedMoE(CustomOp):
|
|||||||
weights = list(self.named_parameters())
|
weights = list(self.named_parameters())
|
||||||
weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
|
weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
|
||||||
|
|
||||||
|
# `w13_input_scale` and `w2_input_scale` are global per-tensor
|
||||||
|
# activation scales shared across all experts (e.g. NVFP4).
|
||||||
|
# They are broadcast views (stride 0) from .expand() and are
|
||||||
|
# not actual expert weights, so exclude them from EPLB.
|
||||||
|
NON_EXPERT_WEIGHTS = {
|
||||||
|
"e_score_correction_bias",
|
||||||
|
"w13_input_scale",
|
||||||
|
"w2_input_scale",
|
||||||
|
}
|
||||||
|
|
||||||
assert all(
|
assert all(
|
||||||
weight.is_contiguous()
|
weight.is_contiguous()
|
||||||
for name, weight in weights
|
for name, weight in weights
|
||||||
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
|
if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
|
||||||
|
and name not in NON_EXPERT_WEIGHTS
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter out the non-expert weights.
|
|
||||||
# `e_score_correction_bias` is a bias for each logical expert,
|
|
||||||
# with shape (num_logical_experts,), not an expert weight.
|
|
||||||
NON_EXPERT_WEIGHTS = {
|
|
||||||
"e_score_correction_bias",
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
weight.view(self.local_num_experts, -1)
|
weight.view(self.local_num_experts, -1)
|
||||||
for name, weight in weights
|
for name, weight in weights
|
||||||
|
|||||||
@@ -489,6 +489,9 @@ class FusedMoEExperts(ABC):
|
|||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.num_dispatchers = num_dispatchers
|
self.num_dispatchers = num_dispatchers
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # noqa: B027
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_monolithic() -> bool:
|
def is_monolithic() -> bool:
|
||||||
raise NotImplementedError("Implemented by subclasses.")
|
raise NotImplementedError("Implemented by subclasses.")
|
||||||
|
|||||||
@@ -374,11 +374,13 @@ def make_nvfp4_moe_quant_config(
|
|||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
g1_alphas = a13_scale * w13_scale_2
|
# Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
|
||||||
g2_alphas = a2_scale * w2_scale_2
|
# The expert's process_weights_after_loading will fuse activation
|
||||||
|
# scales in-place. Since the quant config references the same tensor
|
||||||
|
# as the registered parameter, EPLB rearrangement stays in sync.
|
||||||
return nvfp4_moe_quant_config(
|
return nvfp4_moe_quant_config(
|
||||||
g1_alphas=g1_alphas,
|
g1_alphas=w13_scale_2,
|
||||||
g2_alphas=g2_alphas,
|
g2_alphas=w2_scale_2,
|
||||||
a1_gscale=(1.0 / a13_scale),
|
a1_gscale=(1.0 / a13_scale),
|
||||||
a2_gscale=(1.0 / a2_scale),
|
a2_gscale=(1.0 / a2_scale),
|
||||||
w1_scale=w13_scale,
|
w1_scale=w13_scale,
|
||||||
|
|||||||
@@ -570,6 +570,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
shared_experts=layer.shared_experts,
|
shared_experts=layer.shared_experts,
|
||||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||||
)
|
)
|
||||||
|
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1394,6 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
shared_experts=layer.shared_experts,
|
shared_experts=layer.shared_experts,
|
||||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||||
)
|
)
|
||||||
|
self.moe_kernel.fused_experts.process_weights_after_loading(layer)
|
||||||
|
|
||||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||||
return make_nvfp4_moe_quant_config(
|
return make_nvfp4_moe_quant_config(
|
||||||
|
|||||||
@@ -267,16 +267,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
|||||||
num_experts=w13.size(0),
|
num_experts=w13.size(0),
|
||||||
is_gated_activation=is_gated,
|
is_gated_activation=is_gated,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We do not need to make this a parameter, because
|
|
||||||
# it is not used during the weight (re)-loading process.
|
|
||||||
if is_gated:
|
|
||||||
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
|
||||||
else:
|
|
||||||
layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
|
|
||||||
layer.a1_gscale = 1.0 / a13_scale
|
|
||||||
layer.g1_alphas = a13_scale * w13_scale_2
|
|
||||||
layer.g2_alphas = a2_scale * w2_scale_2
|
|
||||||
else:
|
else:
|
||||||
# Swizzle the block scales for other FI NVFP4 MoE kernels.
|
# Swizzle the block scales for other FI NVFP4 MoE kernels.
|
||||||
w13_scale = swizzle_blockscale(w13_scale)
|
w13_scale = swizzle_blockscale(w13_scale)
|
||||||
|
|||||||
Reference in New Issue
Block a user