Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,8 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
linear_decode_forward_triton)
|
||||
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
@@ -17,8 +16,8 @@ DTYPES = [torch.float32]
|
||||
|
||||
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
"""Reference implementation of lightning attention core algorithm
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
each step sequentially, instead of using parallelized triton kernels
|
||||
"""
|
||||
B, H, S, D = q.shape
|
||||
@@ -62,8 +61,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
|
||||
# where dimension 2 contains both KV and KV history
|
||||
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
|
||||
dim=2) # [B, H, 2, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E]
|
||||
|
||||
return output, final_kv_cache
|
||||
|
||||
@@ -109,7 +107,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
|
||||
out_h = torch.matmul(q_bh, kv_new)
|
||||
|
||||
# Update output and cache
|
||||
output[b, h * D:(h + 1) * D] = out_h
|
||||
output[b, h * D : (h + 1) * D] = out_h
|
||||
kv_caches[b, h] = kv_new
|
||||
|
||||
return output
|
||||
@@ -135,12 +133,9 @@ def test_linear_decode_forward_triton(
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
@@ -150,15 +145,14 @@ def test_linear_decode_forward_triton(
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
torch.testing.assert_close(triton_output,
|
||||
reference_output,
|
||||
rtol=1e-1,
|
||||
atol=1e-1)
|
||||
reference_output = reference_linear_decode(
|
||||
q, k, v, kv_caches_copy, slope_rate, slot_idx
|
||||
)
|
||||
torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1)
|
||||
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
@@ -184,12 +178,9 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
@@ -199,14 +190,15 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
reference_output = reference_linear_decode(
|
||||
q, k, v, kv_caches_copy, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
padding_mask = (slot_idx
|
||||
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
|
||||
triton_masked = triton_output[padding_mask]
|
||||
reference_masked = reference_output[padding_mask]
|
||||
@@ -217,15 +209,11 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
|
||||
for i in range(batch_size):
|
||||
if valid_indices[i] > 0:
|
||||
torch.testing.assert_close(kv_caches[i],
|
||||
kv_caches_copy[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.testing.assert_close(
|
||||
kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
torch.testing.assert_close(triton_masked,
|
||||
reference_masked,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
@@ -249,39 +237,33 @@ def test_lightning_attention_reference(
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
base = 0.01
|
||||
q = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
kv_history = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
ref_output, ref_kv_cache = reference_lightning_attention(
|
||||
q, k, v, ed, 256, kv_history)
|
||||
q, k, v, ed, 256, kv_history
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
||||
|
||||
actual_output, actual_kv_cache = lightning_attention(
|
||||
q, k, v, ed, 256, kv_history_clone)
|
||||
q, k, v, ed, 256, kv_history_clone
|
||||
)
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache,
|
||||
actual_kv_cache,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol)
|
||||
|
||||
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
||||
assert ref_kv_cache.shape == actual_kv_cache.shape
|
||||
|
||||
Reference in New Issue
Block a user