Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -9,9 +9,21 @@ from vllm.triton_utils import tl, triton
 
 
 @triton.jit
-def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
-                     d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
-                     NUM_BLOCK, CBLOCK: tl.constexpr):
+def _fwd_diag_kernel(
+    Q,
+    K,
+    V,
+    Out,
+    S,
+    b: tl.constexpr,
+    h: tl.constexpr,
+    n,
+    d: tl.constexpr,
+    e: tl.constexpr,
+    BLOCK: tl.constexpr,
+    NUM_BLOCK,
+    CBLOCK: tl.constexpr,
+):
     # This kernel computes the diagonal blocks of the attention matrix
     # Each diagonal block represents attention
     # where queries attend to keys in the same block
@@ -39,18 +51,36 @@ def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
     o_cblock_offset = cblock_offset * e
 
     # Calculate pointers to the query, key, value, and output tensors
-    Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
-                   tl.arange(0, CBLOCK)[:, None] * d +
-                   tl.arange(0, d)[None, :])
-    K_trans_block_ptr = (K + qk_offset + qk_block_offset +
-                         tl.arange(0, CBLOCK)[None, :] * d +
-                         tl.arange(0, d)[:, None])
-    V_block_ptr = (V + v_offset + v_block_offset +
-                   tl.arange(0, CBLOCK)[:, None] * e +
-                   tl.arange(0, e)[None, :])
-    O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
-                   tl.arange(0, CBLOCK)[:, None] * e +
-                   tl.arange(0, e)[None, :])
+    Q_block_ptr = (
+        Q
+        + qk_offset
+        + qk_block_offset
+        + q_cblock_offset
+        + tl.arange(0, CBLOCK)[:, None] * d
+        + tl.arange(0, d)[None, :]
+    )
+    K_trans_block_ptr = (
+        K
+        + qk_offset
+        + qk_block_offset
+        + tl.arange(0, CBLOCK)[None, :] * d
+        + tl.arange(0, d)[:, None]
+    )
+    V_block_ptr = (
+        V
+        + v_offset
+        + v_block_offset
+        + tl.arange(0, CBLOCK)[:, None] * e
+        + tl.arange(0, e)[None, :]
+    )
+    O_block_ptr = (
+        Out
+        + o_offset
+        + o_block_offset
+        + o_cblock_offset
+        + tl.arange(0, CBLOCK)[:, None] * e
+        + tl.arange(0, e)[None, :]
+    )
 
     # Load the decay rate for the current head
     S_block_ptr = S + off_h
@@ -60,9 +90,9 @@ def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
     q_index = tl.arange(0, CBLOCK) + i * CBLOCK
 
     # Load query values
-    q = tl.load(Q_block_ptr,
-                mask=block_offset + q_index[:, None] < n,
-                other=0.0).to(tl.float32)
+    q = tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to(
+        tl.float32
+    )
 
     # Initialize output accumulator
     qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
@@ -146,18 +176,30 @@ def _fwd_kv_parallel(
     kv_offset = off_bh * NUM_BLOCK * d * e
 
     # Calculate pointers to the key, value, and key-value tensors
-    K_trans_block_ptr = (K + k_offset + k_block_offset +
-                         tl.arange(0, CBLOCK)[None, :] * d +
-                         tl.arange(0, D_FBLOCK)[:, None])
-    V_block_ptr = (V + v_offset + v_block_offset +
-                   tl.arange(0, CBLOCK)[:, None] * e +
-                   tl.arange(0, E_FBLOCK)[None, :])
-    KV_block_ptr = (KV + kv_offset + kv_block_offset +
-                    tl.arange(0, D_FBLOCK)[:, None] * e +
-                    tl.arange(0, E_FBLOCK)[None, :])
+    K_trans_block_ptr = (
+        K
+        + k_offset
+        + k_block_offset
+        + tl.arange(0, CBLOCK)[None, :] * d
+        + tl.arange(0, D_FBLOCK)[:, None]
+    )
+    V_block_ptr = (
+        V
+        + v_offset
+        + v_block_offset
+        + tl.arange(0, CBLOCK)[:, None] * e
+        + tl.arange(0, E_FBLOCK)[None, :]
+    )
+    KV_block_ptr = (
+        KV
+        + kv_offset
+        + kv_block_offset
+        + tl.arange(0, D_FBLOCK)[:, None] * e
+        + tl.arange(0, E_FBLOCK)[None, :]
+    )
 
     # Load the decay factors for the current head and block
-    k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
+    k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]
 
     kv_index = tl.arange(0, CBLOCK)
 
@@ -177,12 +219,16 @@ def _fwd_kv_parallel(
     for j in range(num_blocks):
         left_bound = (1 - j) * left_shift
         # Load key and value, handling boundary conditions
-        k_trans = tl.load(K_trans_block_ptr - left_shift * d,
-                          mask=kv_index[None, :] >= left_bound,
-                          other=0.0)
-        v = tl.load(V_block_ptr - left_shift * e,
-                    mask=kv_index[:, None] >= left_bound,
-                    other=0.0)
+        k_trans = tl.load(
+            K_trans_block_ptr - left_shift * d,
+            mask=kv_index[None, :] >= left_bound,
+            other=0.0,
+        )
+        v = tl.load(
+            V_block_ptr - left_shift * e,
+            mask=kv_index[:, None] >= left_bound,
+            other=0.0,
+        )
 
         # Load decay factor and compute weighted key-value outer product
         k_decay = tl.load(k_decay_ptr)
@@ -198,9 +244,20 @@ def _fwd_kv_parallel(
 
 
 @triton.jit
-def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
-                   d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
-                   NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
+def _fwd_kv_reduce(
+    S,
+    KV,
+    KV_HISTORY,
+    b: tl.constexpr,
+    h: tl.constexpr,
+    n,
+    d: tl.constexpr,
+    e: tl.constexpr,
+    BLOCK: tl.constexpr,
+    NUM_BLOCK,
+    D_FBLOCK: tl.constexpr,
+    E_FBLOCK: tl.constexpr,
+):
     # This kernel reduces the key-value outer products
     # across blocks and updates the KV history
     off_bh = tl.program_id(0)  # batch-head index
@@ -209,8 +266,12 @@ def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
     kv_offset = off_bh * NUM_BLOCK * d * e
 
     # Calculate pointer to the key-value tensor
-    KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
-                    tl.arange(0, E_FBLOCK)[None, :])
+    KV_block_ptr = (
+        KV
+        + kv_offset
+        + tl.arange(0, D_FBLOCK)[:, None] * e
+        + tl.arange(0, E_FBLOCK)[None, :]
+    )
 
     # Load the decay rate for the current head
     s_ptrs = S + off_h
@@ -218,9 +279,12 @@ def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
 
     # Calculate pointer to the key-value history tensor
     kv_history_offset = off_bh * d * e
-    KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
-                            tl.arange(0, D_FBLOCK)[:, None] * e +
-                            tl.arange(0, E_FBLOCK)[None, :])
+    KV_HISTORY_block_ptr = (
+        KV_HISTORY
+        + kv_history_offset
+        + tl.arange(0, D_FBLOCK)[:, None] * e
+        + tl.arange(0, E_FBLOCK)[None, :]
+    )
 
     # Load the previous key-value history
     kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
@@ -283,12 +347,18 @@ def _fwd_none_diag_kernel(
     kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
 
     # Calculate pointers to the query, output, and key-value tensors
-    Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
-                   tl.arange(0, d)[None, :])
-    O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
-                   tl.arange(0, E_FBLOCK)[None, :])
-    KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
-                    tl.arange(0, E_FBLOCK)[None, :])
+    Q_block_ptr = (
+        Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]
+    )
+    O_block_ptr = (
+        Out
+        + o_offset
+        + tl.arange(0, CBLOCK)[:, None] * e
+        + tl.arange(0, E_FBLOCK)[None, :]
+    )
+    KV_block_ptr = (
+        KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+    )
 
     # Load the decay rate for the current head
     S_block_ptr = S + off_h
@@ -301,8 +371,7 @@ def _fwd_none_diag_kernel(
     q_index = block_offset + tl.arange(0, CBLOCK)
 
     # Load query values
-    q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
-                other=0.).to(tl.float32)
+    q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)
 
     # Compute decay factors for the current sub-block
     q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
@@ -311,20 +380,18 @@ def _fwd_none_diag_kernel(
     qkv_none_diag = tl.dot(q, kv) * q_decay
 
     # Load diagonal attention output (computed by _fwd_diag_kernel)
-    qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
-                       other=0.).to(tl.float32)
+    qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)
 
     # Combine diagonal and non-diagonal attention outputs
     qkv = qkv_diag + qkv_none_diag
 
     # Store the result
-    tl.store(O_block_ptr,
-             qkv.to(O_block_ptr.dtype.element_ty),
-             mask=q_index[:, None] < n)
+    tl.store(
+        O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n
+    )
 
 
 class _attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, s, kv_history):
         # Forward pass of the lightning attention algorithm
@@ -336,8 +403,10 @@ class _attention(torch.autograd.Function):
         # Check CUDA compute capability
         capability = torch.cuda.get_device_capability()
         if capability[0] < 8:
-            raise RuntimeError("Flash attention currently only supported",
-                               "for compute capability >= 80")
+            raise RuntimeError(
+                "Flash attention currently only supported",
+                "for compute capability >= 80",
+            )
 
         # Get input dimensions
         b, h, n, d = q.shape
@@ -360,19 +429,21 @@ class _attention(torch.autograd.Function):
 
         # Step 1: Compute diagonal blocks of attention
         grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
-        _fwd_diag_kernel[grid](q,
-                               k,
-                               v,
-                               o,
-                               s,
-                               b,
-                               h,
-                               n,
-                               d,
-                               e,
-                               BLOCK=BLOCK,
-                               NUM_BLOCK=NUM_BLOCK,
-                               CBLOCK=CBLOCK)
+        _fwd_diag_kernel[grid](
+            q,
+            k,
+            v,
+            o,
+            s,
+            b,
+            h,
+            n,
+            d,
+            e,
+            BLOCK=BLOCK,
+            NUM_BLOCK=NUM_BLOCK,
+            CBLOCK=CBLOCK,
+        )
 
         # Set feature block sizes
         NUM_FBLOCK = 1
@@ -386,9 +457,7 @@ class _attention(torch.autograd.Function):
         assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
 
         # Step 2: Compute key-value outer products for each block in parallel
-        kv = torch.empty((b, h, NUM_BLOCK, d, e),
-                         dtype=torch.float32,
-                         device=q.device)
+        kv = torch.empty((b, h, NUM_BLOCK, d, e), dtype=torch.float32, device=q.device)
         grid = (b * h, NUM_BLOCK)
         _fwd_kv_parallel[grid](
             k,
@@ -412,18 +481,20 @@ class _attention(torch.autograd.Function):
         # Step 3: Reduce key-value outer products
         # across blocks and update KV history
         grid = (b * h, NUM_FBLOCK)
-        _fwd_kv_reduce[grid](s,
-                             kv,
-                             kv_history,
-                             b,
-                             h,
-                             n,
-                             d,
-                             e,
-                             BLOCK=BLOCK,
-                             NUM_BLOCK=NUM_BLOCK,
-                             D_FBLOCK=D_FBLOCK,
-                             E_FBLOCK=E_FBLOCK)
+        _fwd_kv_reduce[grid](
+            s,
+            kv,
+            kv_history,
+            b,
+            h,
+            n,
+            d,
+            e,
+            BLOCK=BLOCK,
+            NUM_BLOCK=NUM_BLOCK,
+            D_FBLOCK=D_FBLOCK,
+            E_FBLOCK=E_FBLOCK,
+        )
 
         # Step 4: Compute non-diagonal blocks of attention
         grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
@@ -461,12 +532,12 @@ def lightning_attention(
     v: torch.Tensor,
     ed: torch.Tensor,
     block_size: int = 256,
-    kv_history: Optional[torch.Tensor] = None
+    kv_history: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
-    Apply lightning attention algorithm 
+    Apply lightning attention algorithm
     to compute attention efficiently.
-    
+
     Args:
         q: Query tensor of shape [batch, heads, seq_len, dim]
         k: Key tensor of shape [batch, heads, seq_len, dim]
@@ -474,7 +545,7 @@ def lightning_attention(
         ed: Decay rate tensor of shape [heads]
         block_size: Size of blocks for block-sparse attention
        kv_history: Optional key-value history from previous computations
-    
+
     Returns:
         output: Attention output
         kv: Updated key-value history
@@ -496,9 +567,9 @@ def lightning_attention(
 
     # Initialize or clone key-value history
     if kv_history is None:
-        kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
-                                 dtype=torch.float32,
-                                 device=q.device)
+        kv_history = torch.zeros(
+            (q.shape[0], q.shape[1], d, e), dtype=torch.float32, device=q.device
+        )
     else:
         kv_history = kv_history.clone().contiguous()
 
@@ -533,7 +604,7 @@ def _linear_attn_decode_kernel(
 ):
     """
     Kernel for linear attention decoding with KV cache.
-    
+
     This kernel computes attention for a single token using the KV cache.
     """
     pid_b = tl.program_id(0)  # batch index
@@ -556,8 +627,9 @@ def _linear_attn_decode_kernel(
     # Calculate offsets for dimensions
     qk_d_offsets = tl.arange(0, D)
     v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
-    cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
-        None, :] * cache_d1_stride
+    cache_d_offsets = (
+        qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride
+    )
 
     # Calculate offsets for the current batch and head
     q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
@@ -605,7 +677,7 @@ def linear_decode_forward_triton(
 ) -> torch.Tensor:
     """
     Perform linear attention decoding using Triton kernels.
-    
+
     Args:
         q: Query tensor of shape [B, H, 1, D]
         k: Key tensor of shape [B, H, 1, D]
@@ -614,7 +686,7 @@ def linear_decode_forward_triton(
         slope_rate: Decay rate tensor
         slot_idx: Slot indices for batches
         BLOCK_SIZE: Size of blocks for processing
-    
+
     Returns:
         output: Attention output tensor
     """
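
For orientation, the sketch below shows how the public entry point touched by this reformat-only commit might be exercised. It is not part of the diff: the import path, tensor shapes, and dtypes are assumptions inferred from the signature and docstring of lightning_attention shown above, and the forward pass requires a GPU with compute capability >= 8.0.

# Hypothetical usage sketch (not from the commit); the module path is assumed.
import torch

from vllm.model_executor.layers.lightning_attn import lightning_attention

# Illustrative sizes; head dims are powers of two so the Triton kernels' tl.arange
# ranges are valid.
b, h, n, d, e = 2, 8, 512, 64, 64  # batch, heads, seq_len, qk head dim, v head dim
q = torch.randn(b, h, n, d, device="cuda", dtype=torch.float16)
k = torch.randn(b, h, n, d, device="cuda", dtype=torch.float16)
v = torch.randn(b, h, n, e, device="cuda", dtype=torch.float16)
ed = torch.rand(h, device="cuda")  # per-head decay rates, shape [heads]

# Leaving kv_history as None lets the wrapper allocate a zeroed float32 history
# of shape [b, h, d, e]; the updated history is returned alongside the output.
output, kv_history = lightning_attention(q, k, v, ed, block_size=256)
print(output.shape, kv_history.shape)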