[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)
Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae> Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae> Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
committed by
GitHub
parent
1f079540db
commit
20bd6f4d2e
@@ -453,7 +453,6 @@ class FalconH1Model(nn.Module):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
input_ids=input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
if get_pp_group().is_first_rank:
|
||||
|
||||
Reference in New Issue
Block a user