Allocate kv_cache with stride order (#16605)
Signed-off-by: shuw <shuw@nvidia.com>
This commit is contained in:
@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing
|
||||
NUM_HEADS = [8] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 120, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
CACHE_LAYOUTS = ["NHD", "HND"]
|
||||
|
||||
# Parameters for MLA tests.
|
||||
KV_LORA_RANKS = [512]
|
||||
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache_flash(
|
||||
kv_cache_factory_flashinfer,
|
||||
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# fp8 conversion requires continugous memory buffer. Reduce the number of
|
||||
# blocks and tokens to consume less memory.
|
||||
num_tokens = num_tokens // 2
|
||||
num_blocks = num_blocks // 2
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device,
|
||||
cache_layout=kv_cache_layout,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0].contiguous(
|
||||
), value_caches[0].contiguous()
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
del key_caches
|
||||
del value_caches
|
||||
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
|
||||
def permute_and_compact(x):
|
||||
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
|
||||
return y.contiguous()
|
||||
|
||||
key_cache_compact = permute_and_compact(key_cache)
|
||||
value_cache_compact = permute_and_compact(value_cache)
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
|
||||
kv_cache_dtype)
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
|
||||
cloned_key_cache = torch.empty_like(key_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
|
||||
kv_cache_dtype)
|
||||
cloned_value_cache = torch.empty_like(value_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache_compact,
|
||||
v_scale.item(), kv_cache_dtype)
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
cloned_key_cache = key_cache_compact.clone()
|
||||
cloned_value_cache = value_cache_compact.clone()
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale, v_scale)
|
||||
key_cache_compact = permute_and_compact(key_cache)
|
||||
value_cache_compact = permute_and_compact(value_cache)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
result_key_cache = torch.empty_like(key_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(result_key_cache,
|
||||
key_cache,
|
||||
key_cache_compact,
|
||||
k_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
result_value_cache = torch.empty_like(value_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(result_value_cache,
|
||||
value_cache,
|
||||
value_cache_compact,
|
||||
v_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
|
||||
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
|
||||
for i in range(num_tokens):
|
||||
block_idx = block_indicies_lst[i]
|
||||
block_offset = block_offsets_lst[i]
|
||||
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
|
||||
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
|
||||
if kv_cache_layout == "NHD":
|
||||
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
|
||||
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
|
||||
else:
|
||||
cloned_key_cache[block_idx, :, block_offset, :] = key[i]
|
||||
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
torch.testing.assert_close(result_key_cache,
|
||||
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
else:
|
||||
torch.testing.assert_close(key_cache, cloned_key_cache)
|
||||
torch.testing.assert_close(value_cache, cloned_value_cache)
|
||||
torch.testing.assert_close(key_cache_compact, cloned_key_cache)
|
||||
torch.testing.assert_close(value_cache_compact, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
|
||||
|
||||
Reference in New Issue
Block a user