[Model] Support NVLM-D and fix QK Norm in InternViT (#9045)

Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Cyrus Leung
2024-10-07 19:55:12 +08:00
committed by GitHub
parent f19da64871
commit 151ef4efd2
12 changed files with 518 additions and 236 deletions

View File

@@ -18,10 +18,16 @@ class RMSNorm(CustomOp):
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
def forward_native(
self,
@@ -35,7 +41,23 @@ class RMSNorm(CustomOp):
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
@@ -48,6 +70,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm import _custom_ops as ops
if residual is not None:
@@ -72,6 +97,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm._ipex_ops import ipex_ops as ops
if residual is not None: