[IR][RmsNorm] pass None if not has_weight (#38961)

Signed-off-by: Linkun Chen <github@lkchen.net>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Linkun
2026-04-04 08:02:30 -07:00
committed by GitHub
parent 2a36d8fb72
commit a88ce94bbb

View File

@@ -241,8 +241,12 @@ class RMSNorm(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
if residual is None:
# TODO(luka): address the weight=None passing issue more generally
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
x,
self.weight.data if self.has_weight else None,
self.variance_epsilon,
self.variance_size_override,
)
return self.forward_static(