[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)
Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
@@ -9,6 +9,7 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import align_to_256bytes

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -18,6 +19,13 @@ NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]
@@ -432,3 +440,257 @@ def test_fp8_e4m3_conversion(
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    align_cache: bool,
) -> torch.Tensor:
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype

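    # A hedged note (assuming align_to_256bytes rounds the row's byte extent
    # up to the next multiple of 256): allocate padded rows, then slice back
    # to entry_size so every row starts on a 256-byte boundary. For example,
    # an fp16 entry of 576 elements (1152 bytes) would pad to 640 elements
    # (1280 bytes); the resulting cache view is non-contiguous.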
    if align_cache:
        alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
        alloc_shape = (num_blocks, block_size, alloc_entry_size)
        cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
        cache = cache_full[..., :entry_size]
    else:
        cache = torch.zeros(num_blocks,
                            block_size,
                            entry_size,
                            dtype=cache_dtype,
                            device=device)
    return cache


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype

    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
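    # For fp8 caches, quantize the random fp16 values into the uint8 storage
    # via convert_fp8 (scale 1.0) so the cache holds valid fp8 bit patterns
    # rather than raw bytes reinterpreted as fp8.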
    if kv_cache_dtype == "fp8":
        temp = torch.zeros_like(cache)
        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
        vals = temp
    cache.copy_(vals)


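# Note: only align_cache=False is parametrized for the concat test below,
# while the copy/swap tests also exercise align_cache=True; presumably the
# aligned layout is not yet supported by concat_and_cache_mla itself.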
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("align_cache", [False])
|
||||
@torch.inference_mode()
|
||||
def test_concat_and_cache_mla(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
align_cache: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device, align_cache)
|
||||
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
||||
|
||||
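    # Reference layout: each flat slot index decomposes into a (block,
    # in-block offset) pair; the compressed KV vector and the RoPE'd key are
    # concatenated into one entry of size kv_lora_rank + qk_rope_head_dim.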
    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache,
                        ref_temp,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)

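    # fp8 bit patterns stored as uint8 cannot be compared with float
    # tolerances directly, so dequantize both the result and the reference
    # back to fp16 before comparing.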
    if kv_cache_dtype == "fp8":
        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
        ops.convert_fp8(result_temp,
                        kv_cache.contiguous(),
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
        ops.convert_fp8(expected_temp,
                        ref_kv_cache,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
        torch.testing.assert_close(result_temp,
                                   expected_temp,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(kv_cache, ref_kv_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode()
def test_copy_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    align_cache: bool,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    kv_caches = []
    for _ in range(num_layers):
        kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                     kv_cache_dtype, device, align_cache)
        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
        kv_caches.append(kv_cache)

    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]

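    # Copy each sampled source block into two distinct destination blocks;
    # copy_blocks_mla consumes a flat (num_pairs, 2) int64 tensor of
    # (src, dst) indices, e.g. [[0, 3], [0, 5]] copies block 0 into blocks
    # 3 and 5 of every layer's cache.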
    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining, 2 * num_mappings)
    block_mapping = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)

    for src, dst in block_mapping:
        for ref_cache in ref_caches:
            ref_cache[dst].copy_(ref_cache[src])

    opcheck(
        torch.ops._C_cache_ops.copy_blocks_mla,
        (kv_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)

    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
        torch.testing.assert_close(kv_cache, ref_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    align_cache: bool,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device, align_cache)
    dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device, align_cache)

    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    src_cache_clone = src_cache.clone()

    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
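    # Unlike copy_blocks_mla above, the mapping tensor here lives on the CPU;
    # swap_blocks presumably expects a host-side mapping because swaps may
    # cross devices and are issued as individual host-driven copies.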
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(kv_lora_rank == KV_LORA_RANKS[0]
              and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
    )

    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.")