Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -7,19 +7,20 @@ import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton)
merge_attn_states as merge_attn_states_triton,
)
from vllm.platforms import current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states_torch(
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
):
p_lse = prefix_lse
s_lse = suffix_lse
@@ -32,15 +33,13 @@ def merge_attn_states_torch(
s_lse = s_lse - max_lse
p_lse_exp = torch.exp(p_lse)
s_lse_exp = torch.exp(s_lse)
out_se = (p_lse_exp + s_lse_exp)
out_se = p_lse_exp + s_lse_exp
if output_lse is not None:
output_lse = torch.log(out_se) + max_lse
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0,
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0,
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output = prefix_output * p_scale + suffix_output * s_scale
return output, output_lse
@@ -55,8 +54,10 @@ all_case_info: list[tuple] = []
def generate_markdown_table():
global all_case_info
table_header = ("| tokens | heads | headsize | dtype "
"| device | torch | triton | cuda | speedup |")
table_header = (
"| tokens | heads | headsize | dtype "
"| device | torch | triton | cuda | speedup |"
)
table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |"
def shortly_dtype(dtype: torch.dtype) -> str:
@@ -68,16 +69,26 @@ def generate_markdown_table():
print(table_header)
print(table_separator)
for info in all_case_info:
(num_tokens, num_heads, head_size, dtype, device,
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
performance_improved) = info
(
num_tokens,
num_heads,
head_size,
dtype,
device,
avg_time_torch_kernel,
avg_time_triton_kernel,
avg_time_cuda_kernel,
performance_improved,
) = info
dtype = shortly_dtype(dtype)
device = shortly_device(device)
print(f"| {num_tokens} | {num_heads} | {head_size} "
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
f"| {avg_time_triton_kernel:.5f}ms "
f"| {avg_time_cuda_kernel:.5f}ms "
f"| {performance_improved:.4f}x |")
print(
f"| {num_tokens} | {num_heads} | {head_size} "
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
f"| {avg_time_triton_kernel:.5f}ms "
f"| {avg_time_cuda_kernel:.5f}ms "
f"| {performance_improved:.4f}x |"
)
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@@ -85,29 +96,28 @@ def generate_markdown_table():
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype):
def test_merge_attn_states(
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
):
if not current_platform.is_cuda():
pytest.skip('Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel')
pytest.skip(
"Currently only support compare triton merge_attn_states "
"with custom cuda merge_attn_states kernel"
)
NUM_TOKENS = num_tokens
NUM_HEADS = num_query_heads
HEAD_SIZE = head_size
print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"Device: {current_platform.get_device_name()}")
print(
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"Device: {current_platform.get_device_name()}"
)
# prefix_lse and suffix_lse contain inf and normal values
prefix_lse = torch.randn(NUM_HEADS,
NUM_TOKENS,
dtype=torch.float32,
device="cuda")
suffix_lse = torch.randn(NUM_HEADS,
NUM_TOKENS,
dtype=torch.float32,
device="cuda")
prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
# Generate boolean masks
mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
@@ -117,23 +127,23 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
prefix_lse[mask_prefix] = float('inf')
suffix_lse[mask_suffix] = float('inf')
prefix_lse[mask_prefix] = float("inf")
suffix_lse[mask_suffix] = float("inf")
# Other input tensors (need to be initialized but
# no actual calculation needed)
output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
dtype=torch.float32,
device="cuda")
prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
output = torch.zeros(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
)
output_lse = torch.zeros(
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
)
prefix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
)
suffix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
)
warmup_times = 2
repeat_times = 20
@@ -149,15 +159,25 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
suffix_lse_torch = suffix_lse.clone()
for _ in range(warmup_times):
output_torch, output_lse_torch = merge_attn_states_torch(
output_torch, prefix_output, prefix_lse_torch, suffix_output,
suffix_lse_torch, output_lse_torch)
output_torch,
prefix_output,
prefix_lse_torch,
suffix_output,
suffix_lse_torch,
output_lse_torch,
)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
output_torch, output_lse_torch = merge_attn_states_torch(
output_torch, prefix_output, prefix_lse_torch, suffix_output,
suffix_lse_torch, output_lse_torch)
output_torch,
prefix_output,
prefix_lse_torch,
suffix_output,
suffix_lse_torch,
output_lse_torch,
)
end.record()
torch.cuda.synchronize()
total_time_torch_kernel += start.elapsed_time(end)
@@ -173,16 +193,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
end = torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times):
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
suffix_output, suffix_lse,
output_lse_ref_triton)
merge_attn_states_triton(
output_ref_triton,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_ref_triton,
)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
suffix_output, suffix_lse,
output_lse_ref_triton)
merge_attn_states_triton(
output_ref_triton,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_ref_triton,
)
end.record()
torch.cuda.synchronize()
total_time_triton_kernel += start.elapsed_time(end)
@@ -195,14 +225,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
output_lse_cuda = output_lse.clone()
for _ in range(warmup_times):
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse_cuda)
merge_attn_states_cuda(
output_cuda,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_cuda,
)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse_cuda)
merge_attn_states_cuda(
output_cuda,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_cuda,
)
end.record()
torch.cuda.synchronize()
total_time_cuda_kernel += start.elapsed_time(end)
@@ -213,8 +255,10 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
f"Performance: {performance_improved:.5f}x")
print(
f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
f"Performance: {performance_improved:.5f}x"
)
print("-" * 100)
# 4. Correctness compare
@@ -232,35 +276,45 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
# states operation.
output_ref = output_ref_triton
output_lse_ref = output_lse_ref_triton
torch.testing.assert_close(output_cuda.float(),
output_ref.float(),
atol=1e-3,
rtol=rtol)
torch.testing.assert_close(
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
)
print("Output all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
print("-" * 100)
torch.testing.assert_close(output_lse_cuda.float(),
output_lse_ref.float(),
atol=1e-3,
rtol=rtol)
torch.testing.assert_close(
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
)
print("Output LSE all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
print("-" * 100)
print("All output values test passed! All inf values "
"are correctly replaced with -inf.")
print(
"All output values test passed! All inf values "
"are correctly replaced with -inf."
)
print("-" * 100)
device = current_platform.get_device_name()
all_case_info.append(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device,
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
performance_improved))
if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
len(NUM_QUERY_HEADS) * len(DTYPES)):
(
NUM_TOKENS,
NUM_HEADS,
HEAD_SIZE,
output_dtype,
device,
avg_time_torch_kernel,
avg_time_triton_kernel,
avg_time_cuda_kernel,
performance_improved,
)
)
if len(all_case_info) == (
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
):
generate_markdown_table()