[Kernel] Use flashinfer for decoding (#4353)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
This commit is contained in:
Lily Liu
2024-05-03 15:51:27 -07:00
committed by GitHub
parent f8e7adda21
commit 43c413ec57
15 changed files with 600 additions and 53 deletions

View File

@@ -5,6 +5,7 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm._C import cache_ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -191,6 +192,82 @@ def test_reshape_and_cache(
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8":
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=device)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory_flashinfer(
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()
for i in range(num_tokens):
block_idx = block_indicies[i]
block_offset = block_offsets[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)