Revert "[vLLM IR] gemma_rms_norm" (#38998)
This commit is contained in:
@@ -16,6 +16,7 @@ def rms_norm(
|
||||
x_var = x if variance_size is None else x[..., :variance_size]
|
||||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + epsilon)
|
||||
x = x.to(orig_dtype)
|
||||
if weight is not None:
|
||||
x = x.to(weight.dtype) * weight
|
||||
return x.to(orig_dtype)
|
||||
x = x * weight
|
||||
return x
|
||||
|
||||
@@ -36,11 +36,13 @@ AITER_SUPPORTED = is_aiter_found()
|
||||
|
||||
rms_no_var_16bit_only = (
|
||||
lambda x, weight, epsilon, variance_size=None: variance_size is None
|
||||
and x.dtype in (torch.float16, torch.bfloat16)
|
||||
and (weight is None or weight.dtype == x.dtype)
|
||||
and x.dtype
|
||||
in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
|
||||
and requires weight dtype to match x dtype."""
|
||||
"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
|
||||
|
||||
@@ -11,11 +11,8 @@ current_platform.import_kernels()
|
||||
CUDA_ALIKE = current_platform.is_cuda_alike()
|
||||
"""Most kernels in this file are supported on all CUDA-alike platforms."""
|
||||
|
||||
rms_no_var_size = (
|
||||
lambda x, weight, epsilon, variance_size=None: variance_size is None
|
||||
and (weight is None or weight.dtype == x.dtype)
|
||||
)
|
||||
"""vLLM kernel does not support variance_size parameter or mismatched weight dtype."""
|
||||
rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None
|
||||
"""vLLM kernel does not support variance_size parameter."""
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
|
||||
|
||||
@@ -18,9 +18,7 @@ def is_xpu_kernels_found() -> bool:
|
||||
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
|
||||
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
|
||||
|
||||
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None and (
|
||||
weight is None or weight.dtype == x.dtype
|
||||
)
|
||||
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None
|
||||
|
||||
|
||||
@ir.ops.rms_norm.register_impl(
|
||||
|
||||
@@ -376,17 +376,29 @@ class GemmaRMSNorm(CustomOp):
|
||||
self.weight = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@staticmethod
|
||||
def _forward_static_no_residual(
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
if residual is None:
|
||||
return ir.ops.rms_norm(
|
||||
x, self.weight.data.float() + 1.0, self.variance_epsilon
|
||||
)
|
||||
else:
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward() without residual."""
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
x = x * (1.0 + weight.float())
|
||||
x = x.to(orig_dtype)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _forward_static_with_residual(
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward() with residual."""
|
||||
orig_dtype = x.dtype
|
||||
x = (
|
||||
x.float() + residual.float()
|
||||
@@ -394,15 +406,47 @@ class GemmaRMSNorm(CustomOp):
|
||||
else x + residual
|
||||
)
|
||||
residual = x
|
||||
return ir.ops.rms_norm(
|
||||
x, self.weight.data.float() + 1.0, self.variance_epsilon
|
||||
).to(orig_dtype), residual
|
||||
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
x = x * (1.0 + weight.float())
|
||||
x = x.to(orig_dtype)
|
||||
return x, residual
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
if residual is None:
|
||||
return self._forward_static_no_residual(
|
||||
self.weight.data, self.variance_epsilon, x
|
||||
)
|
||||
else:
|
||||
return self._forward_static_with_residual(
|
||||
self.weight.data, self.variance_epsilon, x, residual
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if torch.compiler.is_compiling():
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
if not getattr(self, "_is_compiled", False):
|
||||
self._forward_static_no_residual = torch.compile( # type: ignore
|
||||
self._forward_static_no_residual
|
||||
)
|
||||
self._forward_static_with_residual = torch.compile( # type: ignore
|
||||
self._forward_static_with_residual
|
||||
)
|
||||
self._is_compiled = True
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user