[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -178,14 +178,11 @@ class RMSNorm(CustomOp):
         self.variance_size_override = (
             None if var_hidden_size == hidden_size else var_hidden_size
         )
+        weight_dtype = dtype or torch.get_default_dtype()
         self.has_weight = has_weight
-        if dtype is not None:
-            self.weight = torch.ones(hidden_size, dtype=dtype)
-        else:
-            self.weight = torch.ones(hidden_size)
+        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
-        weight_dtype = self.weight.data.dtype

         if current_platform.is_rocm():
             self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
@@ -195,46 +192,68 @@ class RMSNorm(CustomOp):
                 with_fused_add=True, dtype=weight_dtype
             )

+    @staticmethod
+    def forward_static(
+        x: torch.Tensor,
+        variance_epsilon: float,
+        hidden_size: int,
+        orig_dtype: torch.dtype,
+        weight: torch.Tensor | None = None,
+        residual: torch.Tensor | None = None,
+        variance_size_override: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        x = x.to(torch.float32)
+        if residual is not None:
+            # residual promoted f16->f32 automatically,
+            # otherwise Inductor eliminates the casts to and from f16,
+            # increasing memory usage (and complicating pattern matching)
+            x = x + residual
+            residual = x.to(orig_dtype)
+
+        if x.shape[-1] != hidden_size:
+            raise ValueError(
+                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
+            )
+
+        if variance_size_override is None:
+            x_var = x
+        else:
+            if hidden_size < variance_size_override:
+                raise ValueError(
+                    "Expected hidden_size to be at least "
+                    f"{variance_size_override}, but found: {hidden_size}"
+                )
+
+            x_var = x[:, :, :variance_size_override]
+
+        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
+
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        x = x.to(orig_dtype)
+        if weight is not None:
+            x = x * weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        x = x.to(torch.float32)
-        if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
-
-        hidden_size = x.shape[-1]
-        if hidden_size != self.hidden_size:
-            raise ValueError(
-                "Expected hidden_size to be "
-                f"{self.hidden_size}, but found: {hidden_size}"
-            )
-
-        if self.variance_size_override is None:
-            x_var = x
-        else:
-            if hidden_size < self.variance_size_override:
-                raise ValueError(
-                    "Expected hidden_size to be at least "
-                    f"{self.variance_size_override}, but found: {hidden_size}"
-                )
-
-            x_var = x[:, :, : self.variance_size_override]
-
-        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
-
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype)
-        if self.has_weight:
-            x = x * self.weight
-        if residual is None:
-            return x
-        else:
-            return x, residual
+        return self.forward_static(
+            x,
+            self.variance_epsilon,
+            self.hidden_size,
+            x.dtype,
+            self.weight.data if self.has_weight else None,
+            residual,
+            self.variance_size_override,
+        )

     def forward_cuda(
         self,
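
For reference, below is a minimal standalone sketch of the computation the new forward_static performs: an optional fused residual add in float32, RMS normalization, and a cast back to the input dtype. It is illustrative only (the function name rms_norm_reference and the shapes are made up for this sketch, and the variance_size_override path is omitted); it is not part of this diff or of the vLLM API.

import torch


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float,
    residual: torch.Tensor | None = None,
):
    """RMSNorm in float32 with an optional fused residual add."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        # Add the residual in float32, then keep the pre-norm sum
        # (cast back to the original dtype) as the new residual stream.
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)

    # Normalize by the root-mean-square over the hidden dimension.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = (x * torch.rsqrt(variance + eps)).to(orig_dtype)
    if weight is not None:
        x = x * weight
    return x if residual is None else (x, residual)


# Quick numerical check in half precision.
hidden = 64
x = torch.randn(2, 8, hidden, dtype=torch.float16)
res = torch.randn_like(x)
w = torch.ones(hidden, dtype=torch.float16)
out, new_res = rms_norm_reference(x, w, eps=1e-6, residual=res)
assert out.shape == x.shape and new_res.dtype == x.dtype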