[kernel][perf] support uncontiguous input for rms_norm kernel (#28103)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
zhrrr
2025-11-21 11:39:09 +08:00
committed by GitHub
parent 0e741c12e3
commit a982f5b5ea
4 changed files with 77 additions and 33 deletions

View File

@@ -328,10 +328,7 @@ def rotary_embedding(
def rms_norm(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
# If removed, also need to remove contiguous in MatcherRMSNorm
input_contiguous = input.contiguous()
torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
torch.ops._C.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(

View File

@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
# TODO: support non-contiguous input for RMSNorm and remove this
input_contiguous = input.contiguous()
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input_contiguous,
input=input,
weight=weight,
epsilon=self.epsilon,
)