Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
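In a migration like this, the yapf and isort settings are typically replaced by a ruff section in pyproject.toml. The block below is only an illustrative sketch using ruff's documented options; the exact configuration added for vLLM is not shown in this diff:

# pyproject.toml (illustrative sketch, not the actual vLLM settings)
[tool.ruff]
line-length = 88          # ruff's default line length

[tool.ruff.lint]
select = ["E", "F", "I"]  # "I" provides isort-compatible import sorting

[tool.ruff.format]
quote-style = "double"

# Typical commands after the switch:
#   ruff check --fix .    # linting and import sorting (replaces isort)
#   ruff format .         # code formatting (replaces yapf)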
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Custom normalization layers."""
+
 from typing import Optional, Union
 
 import torch
@@ -14,13 +15,14 @@ from vllm.utils import direct_register_custom_op
 
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
-    return envs.VLLM_ROCM_USE_AITER_RMSNORM \
-        and envs.VLLM_ROCM_USE_AITER
+    return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
 
 
-def rms_norm(x: torch.Tensor, weight: torch.Tensor,
-             variance_epsilon: float) -> torch.Tensor:
+def rms_norm(
+    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
     out = torch.empty_like(x)
     ops.rms_norm(
         out,
@@ -32,9 +34,13 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
 
 
 def fused_add_rms_norm(
-        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
-        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
 
     ops.fused_add_rms_norm(
         x,
         residual,
@@ -44,9 +50,11 @@ def fused_add_rms_norm(
     return x, residual
 
 
-def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
-              variance_epsilon: float) -> torch.Tensor:
+def poly_norm(
+    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
+) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
     out = torch.empty_like(x)
     ops.poly_norm(
         out,
@@ -58,9 +66,11 @@ def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
     return out
 
 
-def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
-                             variance_epsilon: float) -> torch.Tensor:
+def rocm_aiter_rms_norm_impl(
+    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+) -> torch.Tensor:
     import aiter as rocm_aiter
 
     if x.dim() > 2:
         x_original_shape = x.shape
         x = x.reshape(-1, x_original_shape[-1])
@@ -71,9 +81,11 @@ def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
 
 
 def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
-        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
-        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
-
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
     import aiter as rocm_aiter
 
     residual_out = torch.empty_like(residual)
@@ -89,14 +101,18 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
     return output, residual_out
 
 
-def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
-                             variance_epsilon: float) -> torch.Tensor:
+def rocm_aiter_rms_norm_fake(
+    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
+) -> torch.Tensor:
     return torch.empty_like(x)
 
 
 def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
-        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
-        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    variance_epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
     return torch.empty_like(x), torch.empty_like(residual)
 
 
@@ -116,7 +132,8 @@ if current_platform.is_rocm():
 
 def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
     use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
-        torch.float16, torch.bfloat16
+        torch.float16,
+        torch.bfloat16,
     ]
 
     if use_aiter and with_fused_add:
@@ -150,8 +167,9 @@ class RMSNorm(CustomOp):
 
         self.hidden_size = hidden_size
         self.variance_epsilon = eps
-        self.variance_size_override = (None if var_hidden_size == hidden_size
-                                       else var_hidden_size)
+        self.variance_size_override = (
+            None if var_hidden_size == hidden_size else var_hidden_size
+        )
         self.has_weight = has_weight
         if dtype is not None:
             self.weight = torch.ones(hidden_size, dtype=dtype)
@@ -163,9 +181,11 @@ class RMSNorm(CustomOp):
 
         if current_platform.is_rocm():
             self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
-                with_fused_add=False, dtype=weight_dtype)
+                with_fused_add=False, dtype=weight_dtype
+            )
             self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
-                with_fused_add=True, dtype=weight_dtype)
+                with_fused_add=True, dtype=weight_dtype
+            )
 
     def forward_native(
         self,
@@ -181,8 +201,10 @@ class RMSNorm(CustomOp):
 
         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:
-            raise ValueError("Expected hidden_size to be "
-                             f"{self.hidden_size}, but found: {hidden_size}")
+            raise ValueError(
+                "Expected hidden_size to be "
+                f"{self.hidden_size}, but found: {hidden_size}"
+            )
 
         if self.variance_size_override is None:
             x_var = x
@@ -190,9 +212,10 @@ class RMSNorm(CustomOp):
             if hidden_size < self.variance_size_override:
                 raise ValueError(
                     "Expected hidden_size to be at least "
-                    f"{self.variance_size_override}, but found: {hidden_size}")
+                    f"{self.variance_size_override}, but found: {hidden_size}"
+                )
 
-            x_var = x[:, :, :self.variance_size_override]
+            x_var = x[:, :, : self.variance_size_override]
 
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
 
@@ -215,8 +238,9 @@ class RMSNorm(CustomOp):
 
         add_residual = residual is not None
         if add_residual:
-            return fused_add_rms_norm(x, residual, self.weight.data,
-                                      self.variance_epsilon)
+            return fused_add_rms_norm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
         else:
             return rms_norm(x, self.weight.data, self.variance_epsilon)
 
@@ -230,11 +254,11 @@ class RMSNorm(CustomOp):
 
         add_residual = residual is not None
         if add_residual:
-            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
-                                                self.variance_epsilon)
+            return self.rocm_norm_func_with_add(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
         else:
-            return self.rocm_norm_func(x, self.weight.data,
-                                       self.variance_epsilon)
+            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)
 
     def forward_xpu(
         self,
@@ -315,8 +339,7 @@ class GemmaRMSNorm(CustomOp):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
-        return self.forward_static(self.weight.data, self.variance_epsilon, x,
-                                   residual)
+        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
 
     def forward_cuda(
         self,
@@ -328,7 +351,8 @@ class GemmaRMSNorm(CustomOp):
 
         if not getattr(self, "_is_compiled", False):
             self.forward_static = torch.compile(  # type: ignore
-                self.forward_static)
+                self.forward_static
+            )
             self._is_compiled = True
         return self.forward_native(x, residual)
@@ -352,8 +376,7 @@ class PolyNorm(CustomOp):
         self.variance_epsilon = eps
 
     def _norm(self, x):
-        return x / torch.sqrt(
-            x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
+        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
 
     def forward_native(
         self,
@@ -366,9 +389,12 @@ class PolyNorm(CustomOp):
 
         orig_dtype = x.dtype
         x_float = x.to(torch.float32)
-        output = (self.weight[0] * self._norm(x_float**3) +
-                  self.weight[1] * self._norm(x_float**2) +
-                  self.weight[2] * self._norm(x_float) + self.bias)
+        output = (
+            self.weight[0] * self._norm(x_float**3)
+            + self.weight[1] * self._norm(x_float**2)
+            + self.weight[2] * self._norm(x_float)
+            + self.bias
+        )
         return output.to(orig_dtype)
 
     def forward_cuda(
@@ -391,5 +417,6 @@ class LayerNorm(nn.Module):
         self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
 
     def forward(self, x: torch.Tensor):
-        return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
-                            self.eps).type_as(x)
+        return F.layer_norm(
+            x.float(), (self.dim,), self.weight, self.bias, self.eps
+        ).type_as(x)