[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
527821d191
commit
100b630a60
@@ -9,6 +9,9 @@ import torch
|
||||
from tabulate import tabulate
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (
|
||||
@@ -31,6 +34,8 @@ def run_benchmark(
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
num_iters: int,
|
||||
implementation: str,
|
||||
benchmark_mode: str,
|
||||
device: str = "cuda",
|
||||
) -> float:
|
||||
"""Return latency (seconds) for given num_tokens."""
|
||||
@@ -38,6 +43,14 @@ def run_benchmark(
|
||||
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
|
||||
|
||||
if implementation not in ("cuda", "triton"):
|
||||
raise ValueError(
|
||||
f"Unsupported implementation: {implementation}. "
|
||||
"Only 'cuda' and 'triton' are supported."
|
||||
)
|
||||
if implementation == "triton" and kv_cache_layout == "HND":
|
||||
return float("nan") # Triton does not support HND layout yet.
|
||||
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@@ -65,27 +78,49 @@ def run_benchmark(
|
||||
cache_layout=kv_cache_layout,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
# to free unused memory
|
||||
del key_caches, value_caches
|
||||
|
||||
# compute per-kernel scaling factors for fp8 conversion (if used).
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
|
||||
if implementation == "cuda":
|
||||
function_under_test = lambda: ops.reshape_and_cache_flash(
|
||||
key, # noqa: F821
|
||||
value, # noqa: F821
|
||||
key_cache, # noqa: F821
|
||||
value_cache, # noqa: F821
|
||||
slot_mapping, # noqa: F821
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
function_under_test = lambda: triton_reshape_and_cache_flash(
|
||||
key, # noqa: F821
|
||||
value, # noqa: F821
|
||||
key_cache, # noqa: F821
|
||||
value_cache, # noqa: F821
|
||||
slot_mapping, # noqa: F821
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
if benchmark_mode == "cudagraph":
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
function_under_test()
|
||||
torch.cuda.synchronize()
|
||||
function_under_test = lambda: g.replay()
|
||||
|
||||
def run_cuda_benchmark(n_iters: int) -> float:
|
||||
nonlocal key, value, key_cache, value_cache, slot_mapping
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(n_iters):
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
function_under_test()
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
return (end - start) / n_iters
|
||||
|
||||
@@ -116,10 +151,16 @@ def main(args):
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
kv_cache_layout=layout,
|
||||
num_iters=args.iters,
|
||||
implementation=args.implementation,
|
||||
benchmark_mode=args.mode,
|
||||
device="cuda",
|
||||
)
|
||||
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
|
||||
|
||||
print(
|
||||
f"Benchmark results for implementation {args.implementation}"
|
||||
f" (measuring with {args.mode}):"
|
||||
)
|
||||
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
|
||||
|
||||
|
||||
@@ -151,6 +192,21 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument("--iters", type=int, default=100)
|
||||
|
||||
parser.add_argument(
|
||||
"--implementation",
|
||||
type=str,
|
||||
choices=["cuda", "triton"],
|
||||
default="cuda",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=["cudagraph", "no_graph"],
|
||||
default="cudagraph",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user