Upstream Llama4 Support to Main (#16113)
Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com> Signed-off-by: Chris Thi <chris.c.thi@gmail.com> Signed-off-by: drisspg <drisspguessous@gmail.com> Signed-off-by: Jon Swenson <jmswen@gmail.com> Signed-off-by: Keyun Tong <tongkeyun@gmail.com> Signed-off-by: Lu Fang <fanglu@meta.com> Signed-off-by: Xiaodong Wang <xdwang@meta.com> Signed-off-by: Yang Chen <yangche@fb.com> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> Signed-off-by: Lu Fang <lufang@fb.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Lu Fang <fanglu@fb.com> Co-authored-by: Roger Wang <ywang@roblox.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -92,6 +92,7 @@ class RMSNorm(CustomOp):
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -100,8 +101,10 @@ class RMSNorm(CustomOp):
|
||||
self.variance_size_override = (None if var_hidden_size == hidden_size
|
||||
else var_hidden_size)
|
||||
self.has_weight = has_weight
|
||||
|
||||
self.weight = torch.ones(hidden_size)
|
||||
if dtype is not None:
|
||||
self.weight = torch.ones(hidden_size, dtype=dtype)
|
||||
else:
|
||||
self.weight = torch.ones(hidden_size)
|
||||
if self.has_weight:
|
||||
self.weight = nn.Parameter(self.weight)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user