[vLLM IR] rework gemma_rms_norm (#39014)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Jiangyun Zhu
2026-04-07 16:37:00 +08:00
committed by GitHub
parent da4c0e4db9
commit 8060bb0333
8 changed files with 106 additions and 75 deletions

View File

@@ -376,77 +376,32 @@ class GemmaRMSNorm(CustomOp):
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
@staticmethod
def _forward_static_no_residual(
    weight: torch.Tensor,
    variance_epsilon: float,
    x: torch.Tensor,
) -> torch.Tensor:
    """Apply Gemma-style RMS normalization to ``x`` without a residual.

    All arithmetic runs in float32: ``x`` is multiplied by the reciprocal
    square root of its mean square over the last dimension (with
    ``variance_epsilon`` added), then by ``1 + weight``. The result is cast
    back to the dtype ``x`` had on entry.
    """
    input_dtype = x.dtype
    hidden = x.float()
    mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + variance_epsilon)
    # Scale in float32 first and only then cast back to the input dtype.
    scale = 1.0 + weight.float()
    return (hidden * inv_rms * scale).to(input_dtype)
@staticmethod
def _forward_static_with_residual(
    weight: torch.Tensor,
    variance_epsilon: float,
    x: torch.Tensor,
    residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x``, then apply Gemma-style RMS normalization.

    Returns ``(normalized, new_residual)``. ``normalized`` has the dtype
    ``x`` had on entry. ``new_residual`` is the sum exactly as it was
    computed, so for float16 input it is a float32 tensor.
    """
    input_dtype = x.dtype
    if input_dtype == torch.float16:
        # float16 inputs are summed in float32.
        summed = x.float() + residual.float()
    else:
        summed = x + residual

    hidden = summed.float()
    mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + variance_epsilon)
    # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
    # See https://github.com/huggingface/transformers/pull/29402
    scale = 1.0 + weight.float()
    normalized = (hidden * inv_rms * scale).to(input_dtype)
    return normalized, summed
def forward_native(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward().

    Normalizes ``x`` through ``ir.ops.rms_norm`` using ``1 + weight`` as
    the scale. When ``residual`` is given it is added to ``x`` before
    normalization and the sum is returned alongside the output.

    Returns:
        The normalized tensor cast to the dtype ``x`` had on entry, or a
        ``(normalized, new_residual)`` tuple when ``residual`` is not
        None. ``new_residual`` is the sum as computed, which is float32
        for float16 input.
    """
    orig_dtype = x.dtype
    # Fold the "+ 1" offset into a float32 copy of the weight so it can be
    # passed to rms_norm as an ordinary scale.
    weight = self.weight.data.float() + 1.0
    if residual is not None:
        # float16 inputs are summed in float32; other dtypes add directly.
        x = (
            x.float() + residual.float()
            if orig_dtype == torch.float16
            else x + residual
        )
        residual = x
    # ir.ops.rms_norm handles fp32 upcast internally
    out = ir.ops.rms_norm(x, weight, self.variance_epsilon).to(orig_dtype)
    if residual is None:
        return out
    return out, residual
def forward_cuda(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self._forward_static_no_residual = torch.compile( # type: ignore
self._forward_static_no_residual
)
self._forward_static_with_residual = torch.compile( # type: ignore
self._forward_static_with_residual
)
self._is_compiled = True
return self.forward_native(x, residual)