[Attention] Get rid of mla cache alignment (#14842)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -8,7 +8,6 @@ import torch
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import align_to_256bytes
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -450,22 +449,13 @@ def _create_mla_cache(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     device: str,
-    align_cache: bool,
 ) -> torch.Tensor:
     cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
-
-    if align_cache:
-        alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
-        alloc_shape = (num_blocks, block_size, alloc_entry_size)
-        cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
-        cache = cache_full[..., :entry_size]
-    else:
-        cache = torch.zeros(num_blocks,
-                            block_size,
-                            entry_size,
-                            dtype=cache_dtype,
-                            device=device)
-    return cache
+    return torch.zeros(num_blocks,
+                       block_size,
+                       entry_size,
+                       dtype=cache_dtype,
+                       device=device)
 
 
 def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
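For context, the align_cache=True branch removed above allocated each cache row padded out to a 256-byte boundary and then sliced the view back to entry_size, so the cache under test had a row stride larger than its row length. A minimal sketch of that pattern, with a hypothetical stand-in for the removed align_to_256bytes helper (the exact rounding rule is an assumption, not taken from the vLLM source):

import torch

def _entries_per_256byte_row(entry_size: int, dtype: torch.dtype) -> int:
    # Hypothetical equivalent of align_to_256bytes: round a row of entry_size
    # elements up so that it occupies a whole number of 256-byte chunks.
    elem_bytes = torch.tensor([], dtype=dtype).element_size()
    row_bytes = entry_size * elem_bytes
    padded_bytes = (row_bytes + 255) // 256 * 256
    return padded_bytes // elem_bytes

num_blocks, block_size, entry_size = 4, 16, 576   # illustrative shapes only
alloc_entry_size = _entries_per_256byte_row(entry_size, torch.float16)
full = torch.zeros(num_blocks, block_size, alloc_entry_size, dtype=torch.float16)
aligned_view = full[..., :entry_size]   # rows are 640 elements apart but only 576 wide

print(aligned_view.shape, aligned_view.stride())

This is the layout the tests below stop exercising; after this change they only allocate the contiguous cache returned by the simplified helper.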
@@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False])
 @torch.inference_mode()
 def test_concat_and_cache_mla(
     kv_lora_rank: int,
@@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
 
     scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                 kv_cache_dtype, device, align_cache)
+                                 kv_cache_dtype, device)
     ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
 
     for i in range(num_tokens):
@@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False, True])
 @torch.inference_mode()
 def test_copy_blocks_mla(
     kv_lora_rank: int,
@@ -588,7 +575,6 @@ def test_copy_blocks_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -598,7 +584,7 @@ def test_copy_blocks_mla(
     kv_caches = []
     for _ in range(num_layers):
         kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                     kv_cache_dtype, device, align_cache)
+                                     kv_cache_dtype, device)
         _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
         kv_caches.append(kv_cache)
 
@@ -642,7 +628,6 @@ def test_copy_blocks_mla(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
-@pytest.mark.parametrize("align_cache", [False, True])
 @torch.inference_mode()
 def test_swap_blocks_mla(
     kv_lora_rank: int,
@@ -653,7 +638,6 @@ def test_swap_blocks_mla(
     seed: int,
     device: str,
     kv_cache_dtype: str,
-    align_cache: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -661,9 +645,9 @@ def test_swap_blocks_mla(
     entry_size = kv_lora_rank + qk_rope_head_dim
 
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
     dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
 
     _fill_mla_cache(src_cache, kv_cache_dtype)
     _fill_mla_cache(dst_cache, kv_cache_dtype)
@@ -704,15 +688,14 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("dtype", [torch.float32])
 @pytest.mark.parametrize("kv_cache_dtype",
                          ["auto"])  # You can also test "fp8" if needed.
-@pytest.mark.parametrize("align_cache", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                           num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, align_cache, device):
+                          kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
-                                  kv_cache_dtype, device, align_cache)
+                                  kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
 
     seq_len_tensor = torch.randint(0,
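For reference, a self-contained sketch of the test helper as it reads after this change, assembled from the hunks above; the leading num_blocks / block_size / entry_size parameters and the example shapes are assumptions based on the call sites shown in the tests:

import torch

def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    # fp8 caches are stored as raw bytes; everything else keeps the test dtype.
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
    # Always a plain contiguous allocation now; no 256-byte row padding.
    return torch.zeros(num_blocks,
                       block_size,
                       entry_size,
                       dtype=cache_dtype,
                       device=device)

# Illustrative call mirroring the updated tests (no align_cache argument):
kv_cache = _create_mla_cache(num_blocks=8, block_size=16,
                             entry_size=512 + 64, dtype=torch.float16,
                             kv_cache_dtype="auto", device="cpu")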