[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Burkhard Ringlein
2025-09-23 18:52:40 +02:00
committed by GitHub
parent 527821d191
commit 100b630a60
4 changed files with 276 additions and 20 deletions

View File

@@ -39,6 +39,8 @@ CUDA_DEVICES = [
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@@ -223,6 +225,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
@@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
implementation: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
# fp8 conversion requires continugous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
@@ -298,12 +305,20 @@ def test_reshape_and_cache_flash(
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
if implementation == "cuda":
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)