[Build] Avoid building too many extensions (#1624)
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import layernorm_ops
|
||||
from vllm._C import ops
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@@ -29,7 +29,7 @@ class RMSNorm(nn.Module):
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
layernorm_ops.fused_add_rms_norm(
|
||||
ops.fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
@@ -37,7 +37,7 @@ class RMSNorm(nn.Module):
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
layernorm_ops.rms_norm(
|
||||
ops.rms_norm(
|
||||
out,
|
||||
x,
|
||||
self.weight.data,
|
||||
|
||||
Reference in New Issue
Block a user