[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)

Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
Dhia Eddine Rhaiem
2025-05-22 06:23:59 +04:00
committed by GitHub
parent 1f079540db
commit 20bd6f4d2e
2 changed files with 5 additions and 4 deletions

View File

@@ -453,7 +453,6 @@ class FalconH1Model(nn.Module):
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank: