[kernel][perf] support uncontiguous input for rms_norm kernel (#28103)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
zhrrr
2025-11-21 11:39:09 +08:00
committed by GitHub
parent 0e741c12e3
commit a982f5b5ea
4 changed files with 77 additions and 33 deletions

View File

@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
# TODO: support non-contiguous input for RMSNorm and remove this
input_contiguous = input.contiguous()
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input_contiguous,
input=input,
weight=weight,
epsilon=self.epsilon,
)