[Bugfix][CI/Build] Fix failing Mteb CI (#26638)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -318,7 +318,11 @@ class GemmaRMSNorm(CustomOp):
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
if residual is not None:
|
||||
x = x + residual.float() if orig_dtype == torch.float16 else x + residual
|
||||
x = (
|
||||
x.float() + residual.float()
|
||||
if orig_dtype == torch.float16
|
||||
else x + residual
|
||||
)
|
||||
residual = x
|
||||
|
||||
x = x.float()
|
||||
|
||||
Reference in New Issue
Block a user