Fix MiniMax-M2 rmsnorm precision and remove useless code (#27627)

Signed-off-by: xuebi <xuebi@minimaxi.com>
Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
Roger Young
2025-10-29 21:01:05 +08:00
committed by GitHub
parent ecca3fee76
commit d6704dd099
2 changed files with 1 additions and 19 deletions

View File

@@ -77,7 +77,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
x = (x * self.weight).to(orig_dtype)
return x
def forward(