[kernel][perf] support uncontiguous input for rms_norm kernel (#28103)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com> Signed-off-by: izhuhaoran <izhuhaoran@qq.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
result = torch.empty_like(input)
|
||||
# TODO: support non-contiguous input for RMSNorm and remove this
|
||||
input_contiguous = input.contiguous()
|
||||
_, result = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result,
|
||||
input=input_contiguous,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user