Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-10-05 15:06:22 +01:00; committed by GitHub.
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -4,8 +4,7 @@
import pytest
import torch
from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
from vllm.platforms import current_platform
NUM_HEADS = [4, 8]
@@ -17,8 +16,8 @@ DTYPES = [torch.float32]
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
"""Reference implementation of lightning attention core algorithm
The difference from the main implementation is that this processes
The difference from the main implementation is that this processes
each step sequentially, instead of using parallelized triton kernels
"""
B, H, S, D = q.shape
@@ -62,8 +61,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
# where dimension 2 contains both KV and KV history
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
dim=2) # [B, H, 2, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E]
return output, final_kv_cache
@@ -109,7 +107,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
out_h = torch.matmul(q_bh, kv_new)
# Update output and cache
output[b, h * D:(h + 1) * D] = out_h
output[b, h * D : (h + 1) * D] = out_h
kv_caches[b, h] = kv_new
return output
@@ -135,12 +133,9 @@ def test_linear_decode_forward_triton(
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
)
kv_caches_copy = kv_caches.clone()
@@ -150,15 +145,14 @@ def test_linear_decode_forward_triton(
slot_idx = torch.arange(batch_size, device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
torch.testing.assert_close(triton_output,
reference_output,
rtol=1e-1,
atol=1e-1)
reference_output = reference_linear_decode(
q, k, v, kv_caches_copy, slope_rate, slot_idx
)
torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1)
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
assert triton_output.shape == (batch_size, num_heads * head_size)
@@ -184,12 +178,9 @@ def test_linear_decode_forward_triton_with_padding(
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
)
kv_caches_copy = kv_caches.clone()
@@ -199,14 +190,15 @@ def test_linear_decode_forward_triton_with_padding(
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
reference_output = reference_linear_decode(
q, k, v, kv_caches_copy, slope_rate, slot_idx
)
padding_mask = (slot_idx
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size)
triton_masked = triton_output[padding_mask]
reference_masked = reference_output[padding_mask]
@@ -217,15 +209,11 @@ def test_linear_decode_forward_triton_with_padding(
for i in range(batch_size):
if valid_indices[i] > 0:
torch.testing.assert_close(kv_caches[i],
kv_caches_copy[i],
rtol=rtol,
atol=atol)
torch.testing.assert_close(
kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol
)
torch.testing.assert_close(triton_masked,
reference_masked,
rtol=rtol,
atol=atol)
torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol)
assert triton_output.shape == (batch_size, num_heads * head_size)
@@ -249,39 +237,33 @@ def test_lightning_attention_reference(
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_history = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
)
kv_history_clone = kv_history.clone()
ref_output, ref_kv_cache = reference_lightning_attention(
q, k, v, ed, 256, kv_history)
q, k, v, ed, 256, kv_history
)
from vllm.model_executor.layers.lightning_attn import lightning_attention
actual_output, actual_kv_cache = lightning_attention(
q, k, v, ed, 256, kv_history_clone)
q, k, v, ed, 256, kv_history_clone
)
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
torch.testing.assert_close(ref_kv_cache,
actual_kv_cache,
rtol=rtol,
atol=atol)
torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol)
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
assert ref_kv_cache.shape == actual_kv_cache.shape