[IR][RmsNorm] pass None if not has_weight (#38961)
Signed-off-by: Linkun Chen <github@lkchen.net> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -241,8 +241,12 @@ class RMSNorm(CustomOp):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
if residual is None:
|
||||
# TODO(luka): address the weight=None passing issue more generally
|
||||
return ir.ops.rms_norm(
|
||||
x, self.weight.data, self.variance_epsilon, self.variance_size_override
|
||||
x,
|
||||
self.weight.data if self.has_weight else None,
|
||||
self.variance_epsilon,
|
||||
self.variance_size_override,
|
||||
)
|
||||
|
||||
return self.forward_static(
|
||||
|
||||
Reference in New Issue
Block a user