[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)

Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
Dhia Eddine Rhaiem
2025-05-22 06:23:59 +04:00
committed by GitHub
parent 1f079540db
commit 20bd6f4d2e
2 changed files with 5 additions and 4 deletions

View File

@@ -77,7 +77,7 @@ class Mixer2RMSNormGated(CustomOp):
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
return x.to(input_dtype)
if self.n_groups == 1:
if self.tp_size > 1:
@@ -117,9 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)