[Build] Avoid building too many extensions (#1624)

This commit is contained in:
Yanming W
2023-11-23 16:31:19 -08:00
committed by GitHub
parent de23687d16
commit e0c6f556e8
25 changed files with 206 additions and 272 deletions

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from vllm import pos_encoding_ops
+from vllm._C import ops
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -145,7 +145,7 @@ def test_rotary_embedding(
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
-pos_encoding_ops.rotary_embedding(
+ops.rotary_embedding(
positions,
out_query,
out_key,