# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
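"""Tests for the Triton prefill (context) attention kernel.

The Triton kernel ``context_attention_fwd`` is checked against a PyTorch SDPA
reference implementation, covering causal and non-causal masking, sliding
windows, and grouped-query attention (GQA) head layouts.
"""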

import pytest
import torch
import torch.nn.functional as F

from vllm.attention.ops.triton_prefill_attention import context_attention_fwd


def ref_masked_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    is_causal: bool = True,
    sliding_window_q: int | None = None,
    sliding_window_k: int | None = None,
) -> torch.Tensor:
    """Reference implementation using PyTorch SDPA."""
    # q, k, v: [total_tokens, num_heads, head_dim]
    # SDPA expects [batch, num_heads, seq_len, head_dim]

    total_tokens = q.shape[0]

    # Add batch dimension and transpose
    q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, total_tokens, head_dim]
    k = k.unsqueeze(0).transpose(1, 2)  # [1, num_heads, total_tokens, head_dim]
    v = v.unsqueeze(0).transpose(1, 2)  # [1, num_heads, total_tokens, head_dim]

    # Create attention mask if needed
    attn_mask = None
    use_causal = is_causal

    # If a sliding window is in use, create an explicit mask
    sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
    sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
    if (sliding_window_q > 0) or (sliding_window_k > 0):
        # Position indices
        pos_q = torch.arange(total_tokens, device=q.device).unsqueeze(1)
        pos_k = torch.arange(total_tokens, device=q.device).unsqueeze(0)

        # Start with an all-True mask (True = attend, False = masked out)
        mask = torch.ones(
            (total_tokens, total_tokens), dtype=torch.bool, device=q.device
        )

        # Apply causal mask
        if is_causal:
            mask = mask & (pos_q >= pos_k)

        # Apply sliding window masks: a query attends at most
        # `sliding_window_q` positions behind and `sliding_window_k` ahead
        sliding_window_mask = torch.ones_like(mask)
        if sliding_window_q > 0:
            sliding_window_mask &= pos_q - pos_k <= sliding_window_q

        if sliding_window_k > 0:
            sliding_window_mask &= pos_k - pos_q <= sliding_window_k

        mask = mask & sliding_window_mask

        attn_mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
        use_causal = False  # Don't use is_causal when providing explicit mask

    # Use SDPA
    output = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=use_causal, dropout_p=0.0
    )

    # Convert back to original shape: [total_tokens, num_heads, head_dim]
    output = output.transpose(1, 2).squeeze(0)

    return output


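# NOTE: Both tests pack B variable-length sequences into a single flat
# [total_tokens, num_heads, head_dim] tensor (varlen layout); b_start_loc[i]
# holds the token offset of sequence i. The reference output is computed per
# sequence with ref_masked_attention and written back into the packed tensor
# before comparing against the Triton kernel output.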
@pytest.mark.parametrize("B", [5])
|
|
@pytest.mark.parametrize("max_seq_len", [1024])
|
|
@pytest.mark.parametrize("H_Q", [32])
|
|
@pytest.mark.parametrize("H_KV", [32, 8])
|
|
@pytest.mark.parametrize("D", [128])
|
|
@pytest.mark.parametrize("is_causal", [True, False])
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
|
|
def test_context_attention(
|
|
B: int,
|
|
max_seq_len: int,
|
|
H_Q: int,
|
|
H_KV: int,
|
|
D: int,
|
|
is_causal: bool,
|
|
dtype: torch.dtype,
|
|
):
|
|
"""Test basic context attention without sliding window."""
|
|
torch.manual_seed(42)

    # Generate random sequence lengths for each batch
    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
    total_tokens = seq_lens.sum().item()

    # Create batch start locations
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Create input tensors
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)

    # Call Triton kernel
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=is_causal,
        sliding_window_q=None,
        sliding_window_k=None,
    )

    # Compute reference output for each sequence in batch
    o_ref = torch.zeros_like(q)
    for i in range(B):
        start = b_start_loc[i].item()
        end = start + seq_lens[i].item()

        q_seq = q[start:end]
        k_seq = k[start:end]
        v_seq = v[start:end]

        # Expand KV heads if using GQA
        if H_Q != H_KV:
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)

        o_ref[start:end] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=is_causal,
            sliding_window_q=None,
            sliding_window_k=None,
        )

    # Compare outputs
    torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention_sliding_window(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    sliding_window: tuple[int, int],
    dtype: torch.dtype,
):
    """Test context attention with sliding window."""
    sliding_window_q, sliding_window_k = sliding_window
    torch.manual_seed(42)

    # Generate random sequence lengths for each batch
    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
    total_tokens = seq_lens.sum().item()

    # Create batch start locations
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Create input tensors
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)

    # Call Triton kernel
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=False,
        sliding_window_q=sliding_window_q,
        sliding_window_k=sliding_window_k,
    )

    # Compute reference output for each sequence in batch
    o_ref = torch.zeros_like(q)
    for i in range(B):
        start = b_start_loc[i].item()
        end = start + seq_lens[i].item()

        q_seq = q[start:end]
        k_seq = k[start:end]
        v_seq = v[start:end]

        # Expand KV heads if using GQA
        if H_Q != H_KV:
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)

        o_ref[start:end] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=False,
            sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
            sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
        )

    # Compare outputs
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)