Allocate kv_cache with stride order (#16605)

Signed-off-by: shuw <shuw@nvidia.com>
Shu Wang · 2025-04-26 00:03:31 -05:00 · committed by GitHub
parent b278911229
commit 9e96f56efb
6 changed files with 119 additions and 50 deletions


@@ -16,6 +16,7 @@ NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]
+CACHE_LAYOUTS = ["NHD", "HND"]

 # Parameters for MLA tests.
 KV_LORA_RANKS = [512]
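The two strings name the stride order of each cache block: NHD keeps tokens outermost, HND keeps heads outermost. A minimal shape-only sketch of the difference (dimension sizes are illustrative, not the test's parametrization):

```python
import torch

num_blocks, block_size, num_heads, head_size = 128, 16, 8, 64

# NHD: each block stores a (block_size, num_heads, head_size) tile,
# so all heads of one token are adjacent in memory.
nhd = torch.empty(num_blocks, block_size, num_heads, head_size)

# HND: each block stores a (num_heads, block_size, head_size) tile,
# so all tokens of one head are adjacent in memory.
hnd = torch.empty(num_blocks, num_heads, block_size, head_size)

# Same logical cache; only the middle two axes trade places.
assert nhd.permute(0, 2, 1, 3).shape == hnd.shape
```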
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
     seed: int,
     device: str,
     kv_cache_dtype: str,
+    kv_cache_layout: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

+    # fp8 conversion requires contiguous memory buffer. Reduce the number of
+    # blocks and tokens to consume less memory.
+    num_tokens = num_tokens // 2
+    num_blocks = num_blocks // 2
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping_lst,
                                 dtype=torch.long,
                                 device=device)

     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
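Sampling slots without replacement guarantees each token lands in a distinct cache slot; the test later decomposes each flat slot index into a (block index, in-block offset) pair. A self-contained sketch of that decomposition (values are illustrative):

```python
import random

import torch

block_size, num_blocks, num_tokens = 16, 8, 10
num_slots = block_size * num_blocks

# Sample without replacement so no two tokens share a slot.
slot_mapping = torch.tensor(random.sample(range(num_slots), num_tokens),
                            dtype=torch.long)

# A flat slot index encodes (block, offset) in mixed radix,
# mirroring block_indicies_lst / block_offsets_lst further down.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

# Round trip: the pair reconstructs the flat slot exactly.
assert torch.equal(block_indices * block_size + block_offsets, slot_mapping)
```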
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
         kv_cache_dtype,
         dtype,
         device=device,
+        cache_layout=kv_cache_layout,
     )
-    key_cache, value_cache = key_caches[0].contiguous(
-    ), value_caches[0].contiguous()
+    key_cache, value_cache = key_caches[0], value_caches[0]
     del key_caches
     del value_caches

     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)

+    def permute_and_compact(x):
+        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
+        return y.contiguous()
+
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
-        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
+        cloned_key_cache = torch.empty_like(key_cache_compact,
+                                            dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
                         kv_cache_dtype)
-        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
-                        kv_cache_dtype)
+        cloned_value_cache = torch.empty_like(value_cache_compact,
+                                              dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache_compact,
+                        v_scale.item(), kv_cache_dtype)
     else:
-        cloned_key_cache = key_cache.clone()
-        cloned_value_cache = value_cache.clone()
+        cloned_key_cache = key_cache_compact.clone()
+        cloned_value_cache = value_cache_compact.clone()

     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
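The new permute_and_compact helper is the pivot of the layout-aware test: assuming the factory always hands back a cache whose logical shape is NHD and only the physical stride order varies (which is what the PR's allocator change provides), the helper reorders HND caches into their native axis order and materializes them contiguously, giving the fp8 conversion the dense buffer it needs. A standalone sketch under that assumption:

```python
import torch

def permute_and_compact(x: torch.Tensor, layout: str) -> torch.Tensor:
    # x is logically (num_blocks, block_size, num_heads, head_size) (NHD).
    # For HND caches, swap the token and head axes so the axis order
    # matches the physical stride order, then materialize contiguously.
    y = x if layout == "NHD" else x.permute(0, 2, 1, 3)
    return y.contiguous()

x = torch.randn(4, 16, 8, 64)            # logical NHD shape
compact = permute_and_compact(x, "HND")
assert compact.shape == (4, 8, 16, 64)   # HND axis order
assert compact.is_contiguous()
```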
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
             cond=(head_size == HEAD_SIZES[0]))
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                 slot_mapping, kv_cache_dtype, k_scale, v_scale)
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)

     if kv_cache_dtype == "fp8":
-        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        result_key_cache = torch.empty_like(key_cache_compact,
+                                            dtype=torch.float16)
         ops.convert_fp8(result_key_cache,
-                        key_cache,
+                        key_cache_compact,
                         k_scale.item(),
                         kv_dtype=kv_cache_dtype)
-        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        result_value_cache = torch.empty_like(value_cache_compact,
+                                              dtype=torch.float16)
         ops.convert_fp8(result_value_cache,
-                        value_cache,
+                        value_cache_compact,
                         v_scale.item(),
                         kv_dtype=kv_cache_dtype)
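In the fp8 branch the caches cannot be compared byte-for-byte, so both the kernel output and the reference clone are dequantized back to float16 with the same scale and compared under loose tolerances. A pure-PyTorch sketch of that scale-based round-trip, using torch.float8_e4m3fn directly rather than vLLM's convert_fp8 op:

```python
import torch

x = torch.randn(16, 8, 64, dtype=torch.float16)

# Mirror the test's heuristic scale: keep scaled values well inside fp8 range.
scale = (x.amax() / 64.0).to(torch.float32)

# Quantize with the scale, then dequantize back to float16.
quantized = (x.to(torch.float32) / scale).to(torch.float8_e4m3fn)
roundtrip = (quantized.to(torch.float32) * scale).to(torch.float16)

# e4m3 carries only a few significant bits, hence the loose tolerances.
torch.testing.assert_close(roundtrip, x, atol=0.05, rtol=0.1)
```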
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
     for i in range(num_tokens):
         block_idx = block_indicies_lst[i]
         block_offset = block_offsets_lst[i]
-        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
-        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        if kv_cache_layout == "NHD":
+            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        else:
+            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
+            cloned_value_cache[block_idx, :, block_offset, :] = value[i]

     if kv_cache_dtype == "fp8":
         torch.testing.assert_close(result_key_cache,
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
                                    atol=0.001,
                                    rtol=0.1)
     else:
-        torch.testing.assert_close(key_cache, cloned_key_cache)
-        torch.testing.assert_close(value_cache, cloned_value_cache)
+        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
+        torch.testing.assert_close(value_cache_compact, cloned_value_cache)


 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
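Stepping back, the allocation trick that gives this PR its title can be sketched independently of the test: allocate the buffer in the desired physical dimension order, then permute the view back so callers always see the logical NHD shape. This is a hypothetical helper illustrating the idea, not vLLM's actual allocator:

```python
import torch

def allocate_kv_cache(num_blocks: int, block_size: int, num_heads: int,
                      head_size: int, layout: str) -> torch.Tensor:
    # Callers always see the logical NHD shape
    # (num_blocks, block_size, num_heads, head_size).
    if layout == "NHD":
        return torch.empty(num_blocks, block_size, num_heads, head_size)
    # HND: allocate heads-major storage, then permute the view back to NHD.
    buf = torch.empty(num_blocks, num_heads, block_size, head_size)
    return buf.permute(0, 2, 1, 3)

cache = allocate_kv_cache(4, 16, 8, 64, "HND")
assert cache.shape == (4, 16, 8, 64)      # logical NHD shape
assert cache.stride(2) > cache.stride(1)  # heads-major physical layout
assert not cache.is_contiguous()          # NHD view over HND storage
```

Under this reading, the test's permute_and_compact step for HND is simply re-expressing the view in its storage's own order, which is why a single permute plus contiguous() suffices.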