[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -6,18 +6,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def is_rocm_aiter_rmsnorm_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
def rms_norm(
|
||||
@@ -58,80 +53,34 @@ def fused_add_rms_norm(
|
||||
return x, residual
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
def poly_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
import aiter as rocm_aiter
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if x.dim() > 2:
|
||||
x_original_shape = x.shape
|
||||
x = x.reshape(-1, x_original_shape[-1])
|
||||
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
return x.reshape(x_original_shape)
|
||||
|
||||
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
residual_out = torch.empty_like(residual)
|
||||
output = torch.empty_like(x)
|
||||
rocm_aiter.rmsnorm2d_fwd_with_add(
|
||||
output, # output
|
||||
x, # input
|
||||
residual, # residual input
|
||||
residual_out, # residual output
|
||||
out = torch.empty_like(x)
|
||||
ops.poly_norm(
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
return out
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_fake(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(x), torch.empty_like(residual)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rms_norm",
|
||||
op_func=rocm_aiter_rms_norm_impl,
|
||||
fake_impl=rocm_aiter_rms_norm_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||
)
|
||||
|
||||
|
||||
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
|
||||
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
|
||||
def dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
|
||||
):
|
||||
use_aiter = use_aiter and dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
|
||||
if use_aiter and with_fused_add:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
return rocm_aiter_ops.rms_norm2d_with_add
|
||||
if use_aiter:
|
||||
return torch.ops.vllm.rocm_aiter_rms_norm
|
||||
return rocm_aiter_ops.rms_norm
|
||||
|
||||
# fall back to CUDA implementation
|
||||
if with_fused_add:
|
||||
@@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
|
||||
self.weight = nn.Parameter(self.weight)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=False, dtype=weight_dtype
|
||||
with_fused_add=False,
|
||||
dtype=weight_dtype,
|
||||
use_aiter=aiter_rmsnorm_enabled,
|
||||
)
|
||||
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=True, dtype=weight_dtype
|
||||
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user