[Misc] Add CustomOp interface for device portability (#5255)

This commit is contained in:
Woosuk Kwon
2024-06-05 09:18:19 -07:00
committed by GitHub
parent 974fc9b845
commit 41ca62cf03
7 changed files with 100 additions and 27 deletions

View File

@@ -6,14 +6,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
class SiluAndMul(nn.Module):
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
return out
class GeluAndMul(nn.Module):
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module):
class NewGELU(CustomOp):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.gelu_new(out, x)
return out
class FastGELU(nn.Module):
class FastGELU(CustomOp):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out