[kernel][perf] support uncontiguous input for rms_norm kernel (#28103)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -328,10 +328,7 @@ def rotary_embedding(
 def rms_norm(
     out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
 ) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
-    # If removed, also need to remove contiguous in MatcherRMSNorm
-    input_contiguous = input.contiguous()
-    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
+    torch.ops._C.rms_norm(out, input, weight, epsilon)


 def fused_add_rms_norm(
||||
@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
         weight: torch.Tensor,
     ) -> torch.Tensor:
         result = torch.empty_like(input)
-        # TODO: support non-contiguous input for RMSNorm and remove this
-        input_contiguous = input.contiguous()
         _, result = auto_functionalized(
             RMS_OP,
             result=result,
-            input=input_contiguous,
+            input=input,
             weight=weight,
             epsilon=self.epsilon,
         )
Reference in New Issue
Block a user