[Kernel] Register punica ops directly (#10522)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -6,12 +6,13 @@ maximum ranks.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
# Enable custom op registration
|
||||
import vllm.lora.ops.bgmv_expand
|
||||
import vllm.lora.ops.bgmv_expand_slice
|
||||
import vllm.lora.ops.bgmv_shrink
|
||||
import vllm.lora.ops.sgmv_expand
|
||||
import vllm.lora.ops.sgmv_expand_slice
|
||||
import vllm.lora.ops.sgmv_shrink # noqa: F401
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (generate_data, generate_data_for_expand_nslices,
|
||||
@@ -37,6 +38,16 @@ def assert_close(a, b):
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
# Unlike test_punica_sizes.py, we directly use the custom ops for
|
||||
# testing, which verifies the correct registration of these ops.
|
||||
bgmv_expand = torch.ops.vllm.bgmv_expand
|
||||
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
|
||||
bgmv_shrink = torch.ops.vllm.bgmv_shrink
|
||||
sgmv_expand = torch.ops.vllm.sgmv_expand
|
||||
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
|
||||
sgmv_shrink = torch.ops.vllm.sgmv_shrink
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
|
||||
Reference in New Issue
Block a user