[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Burkhard Ringlein
2025-09-23 18:52:40 +02:00
committed by GitHub
parent 527821d191
commit 100b630a60
4 changed files with 276 additions and 20 deletions

View File

@@ -9,6 +9,9 @@ import torch
from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
@@ -31,6 +34,8 @@ def run_benchmark(
kv_cache_dtype: str,
kv_cache_layout: str,
num_iters: int,
implementation: str,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""
@@ -38,6 +43,14 @@ def run_benchmark(
if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
if implementation not in ("cuda", "triton"):
raise ValueError(
f"Unsupported implementation: {implementation}. "
"Only 'cuda' and 'triton' are supported."
)
if implementation == "triton" and kv_cache_layout == "HND":
return float("nan") # Triton does not support HND layout yet.
current_platform.seed_everything(42)
torch.set_default_device(device)
@@ -65,27 +78,49 @@ def run_benchmark(
cache_layout=kv_cache_layout,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
if implementation == "cuda":
function_under_test = lambda: ops.reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
else:
function_under_test = lambda: triton_reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
torch.cuda.synchronize()
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters
@@ -116,10 +151,16 @@ def main(args):
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_layout=layout,
num_iters=args.iters,
implementation=args.implementation,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
print(
f"Benchmark results for implementation {args.implementation}"
f" (measuring with {args.mode}):"
)
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
@@ -151,6 +192,21 @@ if __name__ == "__main__":
)
parser.add_argument("--iters", type=int, default=100)
parser.add_argument(
"--implementation",
type=str,
choices=["cuda", "triton"],
default="cuda",
)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args()
main(args)