[Misc] Remove unused custom ops copy_blocks and copy_blocks_mla (#30967)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
@@ -40,93 +40,6 @@ KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
num_mappings: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||
pytest.skip()
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
# Generate random block mappings where each source block is mapped to two
|
||||
# destination blocks.
|
||||
assert 2 * num_mappings <= num_blocks
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
|
||||
block_mapping: list[tuple[int, int]] = []
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping.append((src, dst1))
|
||||
block_mapping.append((src, dst2))
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_layers,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device=device
|
||||
).view(-1, 2)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.copy_blocks,
|
||||
(key_caches, value_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(head_size == HEAD_SIZES[0]),
|
||||
)
|
||||
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
|
||||
|
||||
# Run the reference implementation.
|
||||
for src, dst in block_mapping:
|
||||
for cloned_key_cache in cloned_key_caches:
|
||||
cloned_key_cache[dst].copy_(cloned_key_cache[src])
|
||||
for cloned_value_cache in cloned_value_caches:
|
||||
cloned_value_cache[dst].copy_(cloned_value_cache[src])
|
||||
|
||||
# Compare the results.
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
torch.testing.assert_close(key_cache, cloned_key_cache)
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
torch.testing.assert_close(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -763,73 +676,6 @@ def test_concat_and_cache_ds_mla(
|
||||
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks_mla(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
kv_caches = []
|
||||
for _ in range(num_layers):
|
||||
kv_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
kv_caches.append(kv_cache)
|
||||
|
||||
ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
|
||||
|
||||
num_mappings = min(2, num_blocks // 2)
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remaining = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remaining, 2 * num_mappings)
|
||||
block_mapping = []
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping.append((src, dst1))
|
||||
block_mapping.append((src, dst2))
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device=device
|
||||
).view(-1, 2)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
for ref_cache in ref_caches:
|
||||
ref_cache[dst].copy_(ref_cache[src])
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.copy_blocks_mla,
|
||||
(kv_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
|
||||
|
||||
for kv_cache, ref_cache in zip(kv_caches, ref_caches):
|
||||
torch.testing.assert_close(kv_cache, ref_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
|
||||
Reference in New Issue
Block a user