Fix MiniMax-M2 rmsnorm precision and remove useless code (#27627)
Signed-off-by: xuebi <xuebi@minimaxi.com> Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
@@ -77,7 +77,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
if self.tp_world > 1:
|
||||
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype) * self.weight
|
||||
x = (x * self.weight).to(orig_dtype)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user