[Build] Avoid building too many extensions (#1624)

This commit is contained in:
Yanming W
2023-11-23 16:31:19 -08:00
committed by GitHub
parent de23687d16
commit e0c6f556e8
25 changed files with 206 additions and 272 deletions

View File

@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
-from vllm import layernorm_ops
+from vllm._C import ops
class RMSNorm(nn.Module):
@@ -29,7 +29,7 @@ class RMSNorm(nn.Module):
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
-layernorm_ops.fused_add_rms_norm(
+ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
@@ -37,7 +37,7 @@ class RMSNorm(nn.Module):
)
return x, residual
out = torch.empty_like(x)
-layernorm_ops.rms_norm(
+ops.rms_norm(
out,
x,
self.weight.data,