[V0 deprecation] Remove V0 HPU backend (#21131)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -170,26 +170,6 @@ class RMSNorm(CustomOp):
|
||||
else:
|
||||
return norm_func(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
def forward_hpu(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Apply RMS normalization through the HPU fused kernel.

    The kernel is obtained by calling ``rms_norm()`` from
    ``vllm_hpu_extension.kernels``. When that call returns ``None``
    the work is delegated to ``self.forward_native``.

    Args:
        x: Tensor to normalize.
        residual: Optional tensor that ``x`` is added to before
            normalizing. It is updated in place with the sum.

    Returns:
        The normalized tensor alone when ``residual`` is ``None``;
        otherwise a ``(normalized, residual)`` tuple, where the
        normalized tensor is viewed back to the shape ``x`` had on
        entry and ``residual`` holds the updated sum.
    """
    # Imported lazily so the extension is only needed on this path.
    from vllm_hpu_extension.kernels import rms_norm

    fused_norm = rms_norm()
    if fused_norm is None:
        # No fused kernel was provided: use the native implementation.
        return self.forward_native(x, residual)

    if residual is None:
        return fused_norm.apply(x, self.weight, self.variance_epsilon)

    input_shape = x.shape
    # In-place accumulation: the caller's residual tensor is mutated.
    residual += x.view(residual.shape)
    # NOTE: per the original author's note, the fused HPU kernel
    # requires 3D tensors as inputs.
    normed = fused_norm.apply(residual, self.weight,
                              self.variance_epsilon)
    return normed.view(input_shape), residual
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user