[Bug] correct out dtype of rms_norm_gated native path (#35369)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -577,7 +577,7 @@ class RMSNormGated(CustomOp):
         if z is not None and self.norm_before_gate:
             out = out * F.silu(z)
 
-        return out
+        return out.to(x.dtype)
 
     def forward_cuda(
         self, x: torch.Tensor, z: torch.Tensor | None = None
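
Why the cast matters, as a minimal sketch rather than the actual vLLM implementation: the hunk does not show how `out` is computed, so this example assumes the native path upcasts to float32 for the normalization, which is what would leave `out` in a wider dtype than `x`. The function name `rms_norm_gated_native`, its parameters, and the `eps` default are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn.functional as F


def rms_norm_gated_native(x, weight, z=None, eps=1e-5, norm_before_gate=True):
    """Hypothetical stand-in for a native gated RMSNorm path (not the vLLM code)."""
    # Assumption: the reduction runs in float32 for numerical stability.
    out = x.float()
    if z is not None and not norm_before_gate:
        out = out * F.silu(z.float())
    variance = out.pow(2).mean(dim=-1, keepdim=True)
    out = out * torch.rsqrt(variance + eps) * weight.float()
    if z is not None and norm_before_gate:
        out = out * F.silu(z.float())
    # Mirrors the fix: without this cast, a bfloat16 input would yield float32.
    return out.to(x.dtype)


x = torch.randn(2, 8, dtype=torch.bfloat16)
z = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
y = rms_norm_gated_native(x, w, z)
print(y.dtype)
```

Under that assumption, returning `out` directly would hand the caller a float32 tensor for a bfloat16 input, and downstream layers that expect the input dtype would see a mismatch. Casting back with `.to(x.dtype)` keeps the native path's output dtype tied to its input.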