Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
@@ -32,9 +32,7 @@ NUM_BLOCKS = [1024, 10000]
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
@@ -85,24 +83,33 @@ def test_copy_blocks(
block_mapping.append((src, dst2))
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, kv_cache_dtype,
dtype, seed, device)
key_caches, value_caches = kv_cache_factory(
num_blocks,
block_size,
num_layers,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
opcheck(torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]))
opcheck(
torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]),
)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
@@ -115,8 +122,7 @@ def test_copy_blocks(
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
torch.testing.assert_close(value_cache, cloned_value_cache)
@@ -155,10 +161,17 @@ def test_reshape_and_cache(
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_caches, value_caches = kv_cache_factory(
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
@@ -176,12 +189,30 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
opcheck(
torch.ops._C_cache_ops.reshape_and_cache,
(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
),
cond=(head_size == HEAD_SIZES[0]),
)
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
@@ -202,14 +233,12 @@ def test_reshape_and_cache(
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(
result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
)
torch.testing.assert_close(
result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
)
else:
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)
@@ -254,15 +283,8 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
@@ -293,48 +315,73 @@ def test_reshape_and_cache_flash(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache_compact,
v_scale.item(), kv_cache_dtype)
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
)
cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
)
else:
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
if implementation == "cuda":
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
opcheck(
torch.ops._C_cache_ops.reshape_and_cache_flash,
(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
),
cond=(head_size == HEAD_SIZES[0]),
)
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
triton_reshape_and_cache_flash,
)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache_compact,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype)
result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
)
result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
result_value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype,
)
# Run the reference implementation.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -352,14 +399,12 @@ def test_reshape_and_cache_flash(
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(
result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
)
torch.testing.assert_close(
result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
)
else:
torch.testing.assert_close(key_cache_compact, cloned_key_cache)
torch.testing.assert_close(value_cache_compact, cloned_value_cache)
@@ -396,8 +441,8 @@ def test_swap_blocks(
current_platform.seed_everything(seed)
src_device = device if direction[0] == "cuda" else 'cpu'
dst_device = device if direction[1] == "cuda" else 'cpu'
src_device = device if direction[0] == "cuda" else "cpu"
dst_device = device if direction[1] == "cuda" else "cpu"
src_blocks = random.sample(range(num_blocks), num_mappings)
# For the same device, mapping must not overlap
@@ -408,42 +453,62 @@ def test_swap_blocks(
dst_blocks = random.sample(range(num_blocks), num_mappings)
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2)
# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
seed, src_device)
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
src_device,
)
# Create the KV caches on the second device.
dist_key_caches, dist_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
seed, dst_device)
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
dst_device,
)
src_key_caches_clone = src_key_caches[0].clone()
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
do_opcheck = (head_size == HEAD_SIZES[0])
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
cond=do_opcheck)
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
cond=do_opcheck)
do_opcheck = head_size == HEAD_SIZES[0]
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
cond=do_opcheck,
)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
cond=do_opcheck,
)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
for src, dst in block_mapping:
torch.testing.assert_close(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
torch.testing.assert_close(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())
torch.testing.assert_close(
src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
)
torch.testing.assert_close(
src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -489,11 +554,9 @@ def _create_mla_cache(
device: str,
) -> torch.Tensor:
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
return torch.zeros(num_blocks,
block_size,
entry_size,
dtype=cache_dtype,
device=device)
return torch.zeros(
num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
)
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@@ -533,20 +596,16 @@ def test_concat_and_cache_mla(
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
@@ -558,10 +617,7 @@ def test_concat_and_cache_mla(
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
@@ -571,24 +627,18 @@ def test_concat_and_cache_mla(
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
if kv_cache_dtype == "fp8":
result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
ops.convert_fp8(result_temp,
kv_cache.contiguous(),
scale.item(),
kv_dtype=kv_cache_dtype)
ops.convert_fp8(
result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
)
expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
ops.convert_fp8(expected_temp,
ref_kv_cache,
scale.item(),
kv_dtype=kv_cache_dtype)
torch.testing.assert_close(result_temp,
expected_temp,
atol=0.001,
rtol=0.1)
ops.convert_fp8(
expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
)
torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
else:
torch.testing.assert_close(kv_cache, ref_kv_cache)
@@ -620,24 +670,21 @@ def test_concat_and_cache_ds_mla(
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device)
kv_cache = _create_mla_cache(
num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device)
@@ -664,14 +711,16 @@ def test_concat_and_cache_ds_mla(
manual_max = abs(tile_data_float[0])
for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448.
tile_scale = manual_max / 448.0
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8")
ops.convert_fp8(
ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8",
)
for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
@@ -682,8 +731,7 @@ def test_concat_and_cache_ds_mla(
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
for i in range(num_tokens):
slot = slot_mapping[i].item()
@@ -694,12 +742,14 @@ def test_concat_and_cache_ds_mla(
kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
kv_scales = kv_cache_slice.view(torch.float32)[
kv_lora_rank // 4 : kv_lora_rank // 4 + 4
]
ref_scales = ref_cache_slice.view(torch.float32)[
kv_lora_rank // 4 : kv_lora_rank // 4 + 4
]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
@@ -734,8 +784,9 @@ def test_copy_blocks_mla(
kv_caches = []
for _ in range(num_layers):
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache)
@@ -752,9 +803,9 @@ def test_copy_blocks_mla(
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
for src, dst in block_mapping:
for ref_cache in ref_caches:
@@ -795,10 +846,12 @@ def test_swap_blocks_mla(
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
src_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
dst_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(src_cache, kv_cache_dtype)
_fill_mla_cache(dst_cache, kv_cache_dtype)
@@ -810,9 +863,9 @@ def test_swap_blocks_mla(
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining_blocks, num_mappings)
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
@@ -827,7 +880,8 @@ def test_swap_blocks_mla(
src_cache_clone[src].cpu(),
dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.")
f"{dst} in dst_cache.",
)
@pytest.mark.parametrize("kv_lora_rank", [512])
@@ -840,32 +894,36 @@ def test_swap_blocks_mla(
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_gather_and_maybe_dequant_cache_mla(
kv_lora_rank,
qk_rope_head_dim,
block_size,
num_blocks,
max_seq_len,
batch_size,
dtype,
kv_cache_dtype,
device,
):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
src_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
block_table = torch.empty(
(batch_size, num_blocks), dtype=torch.int32, device=device
)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
@@ -893,10 +951,8 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
remaining = s - (tot - 1) * block_size
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
dequantized_last_block = torch.empty_like(last_block_data, dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data, scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)
@@ -907,14 +963,29 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
opcheck(
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
kv_cache_dtype,
scale,
None,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
ops.gather_and_maybe_dequant_cache(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
kv_cache_dtype,
scale,
None,
)
torch.testing.assert_close(dst, expected)
@@ -925,42 +996,46 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize(
"kv_cache_dtype", ["auto"]
) # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_cp_gather_cache_mla(
kv_lora_rank,
qk_rope_head_dim,
block_size,
num_blocks,
max_seq_len,
batch_size,
dtype,
kv_cache_dtype,
device,
):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
src_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
block_table = torch.empty(
(batch_size, num_blocks), dtype=torch.int32, device=device
)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)
expected_batches = []
for b in range(batch_size):
@@ -1016,20 +1091,16 @@ def test_concat_and_cache_mla_cpu(
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
@@ -1041,10 +1112,7 @@ def test_concat_and_cache_mla_cpu(
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
@@ -1054,6 +1122,5 @@ def test_concat_and_cache_mla_cpu(
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache)