[Kernel] Register punica ops directly (#10522)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2024-11-22 01:18:11 +08:00
committed by GitHub
parent da7e702c6f
commit 2385b60d83
7 changed files with 157 additions and 24 deletions

View File

@@ -6,12 +6,13 @@ maximum ranks.
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
# Enable custom op register
import vllm.lora.ops.bgmv_expand
import vllm.lora.ops.bgmv_expand_slice
import vllm.lora.ops.bgmv_shrink
import vllm.lora.ops.sgmv_expand
import vllm.lora.ops.sgmv_expand_slice
import vllm.lora.ops.sgmv_shrink # noqa: F401
from vllm.platforms import current_platform
from .utils import (generate_data, generate_data_for_expand_nslices,
@@ -37,6 +38,16 @@ def assert_close(a, b):
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
# Unlike test_punica_sizes.py, we directly utilize custom op for
# testing, which verifies the correct registration of these ops.
bgmv_expand = torch.ops.vllm.bgmv_expand
bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
bgmv_shrink = torch.ops.vllm.bgmv_shrink
sgmv_expand = torch.ops.vllm.sgmv_expand
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
sgmv_shrink = torch.ops.vllm.sgmv_shrink
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)