[FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature (#14959)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
```diff
@@ -5,7 +5,77 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
+
+
+def is_rocm_aiter_rmsnorm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+def rms_norm(x: torch.Tensor, weight: torch.Tensor,
+             variance_epsilon: float) -> torch.Tensor:
+    from vllm import _custom_ops as ops
+    out = torch.empty_like(x)
+    ops.rms_norm(
+        out,
+        x,
+        weight,
+        variance_epsilon,
+    )
+    return out
+
+
+def fused_add_rms_norm(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    from vllm import _custom_ops as ops
+    ops.fused_add_rms_norm(
+        x,
+        residual,
+        weight,
+        variance_epsilon,
+    )
+    return x, residual
+
+
+def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
+                        variance_epsilon: float) -> torch.Tensor:
+
+    import aiter as rocm_aiter
+    return rocm_aiter.rms_norm(x, weight, variance_epsilon)
+
+
+def rocm_aiter_fused_add_rms_norm(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    import aiter as rocm_aiter
+
+    # Assuming the signature rmsnorm2d_fwd_with_add(out, input, res_in,
+    # res_out, weight, eps); both tensors are updated in place.
+    rocm_aiter.rmsnorm2d_fwd_with_add(
+        x,  # output
+        x,  # input
+        residual,  # residual input
+        residual,  # residual output
+        weight,
+        variance_epsilon,
+    )
+    return x, residual
+
+
+def dispatch_cuda_rmsnorm_func(add_residual: bool):
+    if add_residual:
+        if is_rocm_aiter_rmsnorm_enabled():
+            return rocm_aiter_fused_add_rms_norm
+        return fused_add_rms_norm
+
+    if is_rocm_aiter_rmsnorm_enabled():
+        return rocm_aiter_rms_norm
+    return rms_norm
 
 
 @CustomOp.register("rms_norm")
```
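Review note: below is a minimal sketch of how the new dispatcher is exercised. The shapes, dtype, and device are illustrative assumptions, and it presumes a CUDA/ROCm build of vLLM with this patch applied (the helpers live in `vllm/model_executor/layers/layernorm.py` alongside `RMSNorm`):

```python
# Illustrative sketch only; not part of this diff.
import torch

from vllm.model_executor.layers.layernorm import (
    dispatch_cuda_rmsnorm_func, is_rocm_aiter_rmsnorm_enabled)

# AITER kernels are selected only on ROCm with both env flags enabled:
#   VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_RMSNORM=1
print(is_rocm_aiter_rmsnorm_enabled())

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# No residual: dispatch returns rocm_aiter_rms_norm or the vLLM custom op.
norm_func = dispatch_cuda_rmsnorm_func(add_residual=False)
out = norm_func(x, weight, 1e-6)

# With residual: the fused variant updates x and residual in place.
fused_func = dispatch_cuda_rmsnorm_func(add_residual=True)
out, residual = fused_func(x, residual, weight, 1e-6)
```

On non-ROCm platforms, or with `VLLM_ROCM_USE_AITER` unset, both lookups fall back to the pre-existing vLLM custom ops, so the dispatch leaves current behavior unchanged.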
```diff
@@ -81,24 +151,14 @@ class RMSNorm(CustomOp):
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
 
-        from vllm import _custom_ops as ops
-
-        if residual is not None:
-            ops.fused_add_rms_norm(
-                x,
-                residual,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
-            x,
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
+        add_residual = residual is not None
+        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
+
+        if add_residual:
+            return norm_func(x, residual, self.weight.data,
+                             self.variance_epsilon)
+        else:
+            return norm_func(x, self.weight.data, self.variance_epsilon)
 
     def forward_hpu(
         self,
```
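For sanity-checking either backend, here is a pure-PyTorch reference of the fused-add semantics that both `fused_add_rms_norm` and `rocm_aiter_fused_add_rms_norm` are expected to preserve. The float32 accumulation mirrors `RMSNorm.forward_native`; the shapes and the suggested tolerance are assumptions:

```python
# Reference semantics: add the residual first, return the sum as the new
# residual, and RMS-normalize the sum for the output. Illustrative only.
import torch


def ref_fused_add_rms_norm(x, residual, weight, eps):
    s = x + residual
    variance = s.float().pow(2).mean(dim=-1, keepdim=True)
    out = (s.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out, s


x = torch.randn(4, 4096, dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.float16)

out, new_residual = ref_fused_add_rms_norm(x, residual, weight, 1e-6)
# Compare against norm_func(x.clone(), residual.clone(), weight, 1e-6) from
# the dispatcher above, e.g. torch.testing.assert_close(..., atol=1e-2,
# rtol=1e-2); fp16 kernels will not match bit-for-bit.
```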