[Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064)

Author: bnellnm
Date: 2025-11-11 09:29:02 -05:00
Committed by: GitHub
Parent: fa1970201d
Commit: a1448b4b69
10 changed files with 1064 additions and 948 deletions


@@ -741,15 +741,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# need to delete the original weights to save memory on single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
self.w13_weight = w13_weight
self.w2_weight = w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
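
The hunk above boils down to one pattern: the converted Triton weights stop living on separate w13_weight_triton_tensor / w2_weight_triton_tensor attributes and are bound directly as self.w13_weight and layer.w13_weight (likewise for w2), with the original parameters dropped first to save memory. Below is a minimal, self-contained sketch of that re-binding pattern, assuming plain torch tensors and the hypothetical stand-in classes FakeMoELayer and FakeQuantMethod; it is an illustration of the idea, not vLLM's actual code.

# Minimal sketch (not vLLM's code) of the re-binding pattern shown in the hunk:
# drop the original quantized parameters, then attach the converted tensors to
# both the quant method and the layer under one canonical attribute name.
import torch
import torch.nn as nn


class FakeMoELayer(nn.Module):
    """Hypothetical stand-in for a fused-MoE layer holding checkpoint weights."""

    def __init__(self) -> None:
        super().__init__()
        self.w13_weight = nn.Parameter(torch.zeros(8, 4), requires_grad=False)
        self.w2_weight = nn.Parameter(torch.zeros(4, 8), requires_grad=False)


class FakeQuantMethod:
    """Hypothetical stand-in for the quant method post-processing the weights."""

    def process_weights_after_loading(self, layer: FakeMoELayer) -> None:
        # pretend these are the backend-converted (e.g. Triton) tensors
        w13_weight = layer.w13_weight.data.clone()
        w2_weight = layer.w2_weight.data.clone()

        # unregister the original parameters before re-binding to save memory
        del layer.w13_weight
        del layer.w2_weight
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # one canonical attribute on both objects; no *_triton_tensor indirection
        self.w13_weight = w13_weight
        self.w2_weight = w2_weight
        layer.w13_weight = w13_weight
        layer.w2_weight = w2_weight


method = FakeQuantMethod()
layer = FakeMoELayer()
method.process_weights_after_loading(layer)
assert layer.w13_weight is method.w13_weight

The del before the re-assignment matters: nn.Module refuses to overwrite a registered nn.Parameter with a plain object, so the parameter has to be unregistered before a converted tensor can be attached under the same attribute name.
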
@@ -824,18 +819,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"EP batched experts format"
)
else:
layer.w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
layer.w2_weight = (
self.w2_weight_triton_tensor
if layer.w2_weight is None
else layer.w2_weight
)
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
assert self.moe_quant_config is not None
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
@@ -1070,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
w1=self.w13_weight,
w2=self.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,