[IR][RmsNorm] pass None if not has_weight (#38961)
Signed-off-by: Linkun Chen <github@lkchen.net> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -241,8 +241,12 @@ class RMSNorm(CustomOp):
|
|||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""PyTorch-native implementation equivalent to forward()."""
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
if residual is None:
|
if residual is None:
|
||||||
|
# TODO(luka): address the weight=None passing issue more generally
|
||||||
return ir.ops.rms_norm(
|
return ir.ops.rms_norm(
|
||||||
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
x,
|
||||||
|
self.weight.data if self.has_weight else None,
|
||||||
|
self.variance_epsilon,
|
||||||
|
self.variance_size_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.forward_static(
|
return self.forward_static(
|
||||||
|
|||||||
Reference in New Issue
Block a user