[Misc] Add CustomOp interface for device portability (#5255)

This commit is contained in:
Woosuk Kwon
2024-06-05 09:18:19 -07:00
committed by GitHub
parent 974fc9b845
commit 41ca62cf03
7 changed files with 100 additions and 27 deletions

View File

@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
class RMSNorm(nn.Module):
class RMSNorm(CustomOp):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def _forward(
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
else:
return x, residual
def forward(
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm import _custom_ops as ops
if residual is not None:
ops.fused_add_rms_norm(
x,