[CI][Models] Add VLM Support for Sequence Classification Conversion (#32885)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -278,21 +278,35 @@ class GemmaRMSNorm(CustomOp):
         self.variance_epsilon = eps
 
     @staticmethod
-    def forward_static(
+    def _forward_static_no_residual(
         weight: torch.Tensor,
         variance_epsilon: float,
         x: torch.Tensor,
-        residual: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        """PyTorch-native implementation equivalent to forward()."""
+    ) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward() without residual."""
         orig_dtype = x.dtype
-        if residual is not None:
-            x = (
-                x.float() + residual.float()
-                if orig_dtype == torch.float16
-                else x + residual
-            )
-            residual = x
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        x = x * (1.0 + weight.float())
+        x = x.to(orig_dtype)
+        return x
+
+    @staticmethod
+    def _forward_static_with_residual(
+        weight: torch.Tensor,
+        variance_epsilon: float,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward() with residual."""
+        orig_dtype = x.dtype
+        x = (
+            x.float() + residual.float()
+            if orig_dtype == torch.float16
+            else x + residual
+        )
+        residual = x
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
@@ -301,7 +315,7 @@ class GemmaRMSNorm(CustomOp):
         # See https://github.com/huggingface/transformers/pull/29402
         x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
+        return x, residual
 
     def forward_native(
         self,
@@ -309,7 +323,14 @@ class GemmaRMSNorm(CustomOp):
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
+        if residual is None:
+            return self._forward_static_no_residual(
+                self.weight.data, self.variance_epsilon, x
+            )
+        else:
+            return self._forward_static_with_residual(
+                self.weight.data, self.variance_epsilon, x, residual
+            )
 
     def forward_cuda(
         self,
@@ -320,8 +341,11 @@ class GemmaRMSNorm(CustomOp):
             return self.forward_native(x, residual)
 
         if not getattr(self, "_is_compiled", False):
-            self.forward_static = torch.compile(  # type: ignore
-                self.forward_static
+            self._forward_static_no_residual = torch.compile(  # type: ignore
+                self._forward_static_no_residual
             )
+            self._forward_static_with_residual = torch.compile(  # type: ignore
+                self._forward_static_with_residual
+            )
             self._is_compiled = True
         return self.forward_native(x, residual)
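For quick reference, a minimal usage sketch of the refactored dispatch; the import path, hidden size, and tensor shapes below are assumptions, not part of the commit, while the return behavior (a single tensor without a residual, an (output, residual) tuple with one) follows directly from the new forward_native shown above.

# Hypothetical usage sketch; module path, hidden_size, and shapes are assumptions.
import torch
from vllm.model_executor.layers.layernorm import GemmaRMSNorm  # assumed import path

layer = GemmaRMSNorm(hidden_size=4096, eps=1e-6)

x = torch.randn(2, 16, 4096, dtype=torch.float16)
residual = torch.randn_like(x)

# Without a residual, forward_native dispatches to _forward_static_no_residual
# and returns a single tensor.
out = layer.forward_native(x)

# With a residual, it dispatches to _forward_static_with_residual and returns
# the normalized output together with the updated residual.
out, new_residual = layer.forward_native(x, residual)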