[Bug] correct out dtype of rms_norm_gated native path (#35369)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
zofia
2026-02-27 13:19:51 +08:00
committed by GitHub
parent 487e5c51f7
commit 516cf26698

View File

@@ -577,7 +577,7 @@ class RMSNormGated(CustomOp):
if z is not None and self.norm_before_gate:
out = out * F.silu(z)
return out
return out.to(x.dtype)
def forward_cuda(
self, x: torch.Tensor, z: torch.Tensor | None = None