Refactor llama family models (#2637)

This commit is contained in:
Roy
2024-02-13 16:09:23 +08:00
committed by GitHub
parent f964493274
commit 5c976a7e1a
17 changed files with 236 additions and 2720 deletions

View File

@@ -7,6 +7,31 @@ import torch.nn as nn
from vllm._C import ops
class LayerNorm(nn.LayerNorm):
    """Layer normalization with an optional fused residual addition.

    When ``residual`` is given, it is added to ``x`` before normalization
    and the pre-normalization sum is returned alongside the normalized
    output, so callers can reuse it as the next layer's residual input.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__(hidden_size, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply layer norm; fold in ``residual`` first when provided."""
        # Fast path: plain normalization, single tensor returned.
        if residual is None:
            return super().forward(x)
        # Residual path: normalize the sum and also return the sum itself.
        summed = x + residual
        return super().forward(summed), summed
class RMSNorm(nn.Module):
"""Root mean square normalization.