[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
527821d191
commit
100b630a60
@@ -39,6 +39,8 @@ CUDA_DEVICES = [
|
||||
# We assume fp8 is always enabled for testing.
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
|
||||
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@@ -223,6 +225,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
|
||||
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache_flash(
|
||||
kv_cache_factory_flashinfer,
|
||||
@@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
implementation: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
assert implementation in ["cuda", "triton"]
|
||||
if implementation == "triton" and kv_cache_layout == "HND":
|
||||
pytest.skip("Triton implementation only supports NHD layout.")
|
||||
|
||||
# fp8 conversion requires continugous memory buffer. Reduce the number of
|
||||
# blocks and tokens to consume less memory.
|
||||
@@ -298,12 +305,20 @@ def test_reshape_and_cache_flash(
|
||||
cloned_key_cache = key_cache_compact.clone()
|
||||
cloned_value_cache = value_cache_compact.clone()
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale, v_scale)
|
||||
if implementation == "cuda":
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
elif implementation == "triton":
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash)
|
||||
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
key_cache_compact = permute_and_compact(key_cache)
|
||||
value_cache_compact = permute_and_compact(value_cache)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user