[v1] Add encoder-only/cross attention support to Triton Attention backend (#31406)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
225
tests/kernels/attention/test_triton_prefill_attention.py
Normal file
225
tests/kernels/attention/test_triton_prefill_attention.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
is_causal: bool = True,
|
||||
sliding_window_q: int | None = None,
|
||||
sliding_window_k: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Reference implementation using PyTorch SDPA."""
|
||||
# q, k, v: [total_tokens, num_heads, head_dim]
|
||||
# SDPA expects [batch, num_heads, seq_len, head_dim]
|
||||
|
||||
total_tokens = q.shape[0]
|
||||
|
||||
# Add batch dimension and transpose
|
||||
q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
|
||||
k = k.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
|
||||
v = v.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
|
||||
|
||||
# Create attention mask if needed
|
||||
attn_mask = None
|
||||
use_causal = is_causal
|
||||
|
||||
# If we have sliding window or need custom masking, create explicit mask
|
||||
sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
|
||||
sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
|
||||
if (sliding_window_q > 0) or (sliding_window_k > 0):
|
||||
# Position indices
|
||||
pos_q = torch.arange(total_tokens, device=q.device).unsqueeze(1)
|
||||
pos_k = torch.arange(total_tokens, device=q.device).unsqueeze(0)
|
||||
|
||||
# Start with valid mask (False = no masking)
|
||||
mask = torch.ones(
|
||||
(total_tokens, total_tokens), dtype=torch.bool, device=q.device
|
||||
)
|
||||
|
||||
# Apply causal mask
|
||||
if is_causal:
|
||||
mask = mask & (pos_q >= pos_k)
|
||||
|
||||
# Apply sliding window masks
|
||||
sliding_window_mask = torch.ones_like(mask)
|
||||
if sliding_window_q > 0:
|
||||
sliding_window_mask &= pos_q - pos_k <= sliding_window_q
|
||||
|
||||
if sliding_window_k > 0:
|
||||
sliding_window_mask &= pos_k - pos_q <= sliding_window_k
|
||||
|
||||
mask = mask & sliding_window_mask
|
||||
|
||||
attn_mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
|
||||
use_causal = False # Don't use is_causal when providing explicit mask
|
||||
|
||||
# Use SDPA
|
||||
output = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, is_causal=use_causal, dropout_p=0.0
|
||||
)
|
||||
|
||||
# Convert back to original shape: [total_tokens, num_heads, head_dim]
|
||||
output = output.transpose(1, 2).squeeze(0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B", [5])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    is_causal: bool,
    dtype: torch.dtype,
):
    """Check the Triton prefill kernel against SDPA with no sliding window.

    Sequences of random length are packed back to back along the token
    dimension; the reference is evaluated one sequence at a time.
    """
    torch.manual_seed(42)
    device = "cuda"

    # Random per-sequence lengths in [max_seq_len // 2, max_seq_len].
    seq_lens = torch.randint(
        max_seq_len // 2, max_seq_len + 1, (B,), device=device
    )
    total_tokens = seq_lens.sum().item()

    # Exclusive prefix sum of the lengths: first token index of each sequence.
    b_start_loc = torch.zeros(B, dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device=device)
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
    o = torch.zeros_like(q)

    # Kernel under test; it writes its result into `o`.
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=is_causal,
        sliding_window_q=None,
        sliding_window_k=None,
    )

    # Reference result, computed independently per sequence.
    o_ref = torch.zeros_like(q)
    for start, length in zip(b_start_loc.tolist(), seq_lens.tolist()):
        span = slice(start, start + length)
        k_seq, v_seq = k[span], v[span]

        if H_Q != H_KV:
            # GQA: replicate each KV head so it lines up with its query heads.
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)

        o_ref[span] = ref_masked_attention(
            q[span],
            k_seq,
            v_seq,
            is_causal=is_causal,
            sliding_window_q=None,
            sliding_window_k=None,
        )

    torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention_sliding_window(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    sliding_window: tuple[int, int],
    dtype: torch.dtype,
):
    """Test context attention with sliding window.

    ``sliding_window`` is a ``(window_q, window_k)`` pair; a 0 entry disables
    that side of the window. Attention is always run non-causally here.
    """
    # The docstring must be the first statement; it used to follow this
    # unpacking and was therefore just a discarded string expression.
    sliding_window_q, sliding_window_k = sliding_window
    torch.manual_seed(42)

    # Generate random sequence lengths for each batch
    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
    total_tokens = seq_lens.sum().item()

    # Create batch start locations (exclusive prefix sum of the lengths)
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Create input tensors
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)

    # Call Triton kernel; the result is written into `o`
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=False,
        sliding_window_q=sliding_window_q,
        sliding_window_k=sliding_window_k,
    )

    # Compute reference output for each sequence in batch
    o_ref = torch.zeros_like(q)
    for i in range(B):
        start = b_start_loc[i].item()
        end = start + seq_lens[i].item()

        q_seq = q[start:end]
        k_seq = k[start:end]
        v_seq = v[start:end]

        # Expand KV heads if using GQA
        if H_Q != H_KV:
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)

        o_ref[start:end] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=False,
            sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
            sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
        )

    # Compare outputs
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
|
||||
@@ -114,7 +114,7 @@ def check_model_available(model: str) -> None:
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@create_new_process_for_each_test("spawn")
|
||||
|
||||
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.platforms import current_platform
|
||||
@@ -83,6 +84,13 @@ BATCH_SPECS = {
|
||||
),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
# encoder-only
|
||||
"small_encoder_prefill": BatchSpec(
|
||||
seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
|
||||
),
|
||||
"medium_encoder_prefill": BatchSpec(
|
||||
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -209,6 +217,7 @@ def run_attention_backend(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
sliding_window: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
@@ -276,6 +285,7 @@ def run_attention_backend(
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
|
||||
@@ -299,6 +309,7 @@ def _test_backend_correctness(
|
||||
backend_to_test: list[AttentionBackendEnum | str],
|
||||
mask_mod,
|
||||
*,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
block_size: int = 16,
|
||||
atol: float = 1e-2,
|
||||
rtol: float = 1e-2,
|
||||
@@ -436,6 +447,9 @@ def _test_backend_correctness(
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, vllm_config.cache_config.block_size, device
|
||||
)
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
# For encoder-only, all tokens are prefill tokens
|
||||
common_attn_metadata.causal = False
|
||||
|
||||
# 3. Simulate Paged KV Cache and a realistic slot_mapping
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
@@ -491,6 +505,7 @@ def _test_backend_correctness(
|
||||
value_vllm,
|
||||
kv_cache_for_backend,
|
||||
sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
finally:
|
||||
if reset_kv_cache_layout:
|
||||
@@ -676,3 +691,45 @@ def test_sliding_window_backend_correctness(
|
||||
block_size=128,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_encoder_prefill",
        "medium_encoder_prefill",
    ],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backends with bidirectional sliding window, encoder-only attention.

    The window size is read from the model config and bound into a
    symmetric mask that is handed to ``_test_backend_correctness``.
    """

    def bidi_sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # Keep a key only if it lies strictly within `sliding_window`
        # positions of the query, on either side.
        distance = q_idx + context_len - kv_idx
        return torch.abs(distance) < sliding_window

    spec = BATCH_SPECS[batch_spec_name]
    config = ModelConfig(model=model, max_model_len=max(spec.seq_lens))
    window = config.get_sliding_window()
    mask_mod = partial(bidi_sliding_window_mask_mod, sliding_window=window)

    _test_backend_correctness(
        spec,
        model,
        SLIDING_WINDOW_BACKENDS_TO_TEST,
        mask_mod,
        attn_type=AttentionType.ENCODER_ONLY,
        tensor_parallel_size=tensor_parallel_size,
    )
|
||||
|
||||
271
vllm/attention/ops/triton_prefill_attention.py
Normal file
271
vllm/attention/ops/triton_prefill_attention.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/sgl-project/sglang/blob/97cb762bb65ebf05025eb342de03c184660427a3/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
|
||||
# Changes:
|
||||
# - Add support for sliding window attention
|
||||
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Memory-efficient attention for prefill.
|
||||
It supports page size = 1.
|
||||
"""
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    SLIDING_WINDOW_Q: tl.constexpr,
    SLIDING_WINDOW_K: tl.constexpr,
    Lk: tl.constexpr,
):
    """Blocked attention forward pass over packed variable-length sequences.

    Each program handles one (sequence, query head, block of BLOCK_M query
    rows), selected by program ids 0, 1 and 2. It walks the keys/values of the
    same sequence in blocks of BLOCK_N, keeping a running row maximum (m_i), a
    running normalizer (l_i) and an already-normalized accumulator (acc), and
    finally stores acc into Out.

    Q/K/V/Out are addressed as ``token * stride_*bs + head * stride_*h + d``,
    so the head dimension is read with unit stride. B_Start_Loc and B_Seqlen
    hold each sequence's first token offset and its length. ``kv_group_num``
    consecutive query heads share one KV head. BLOCK_DMODEL is the head-dim
    tile size and Lk the real head dim used to mask the tile's padding.
    IS_CAUSAL restricts a query to keys at positions <= its own.
    SLIDING_WINDOW_Q / SLIDING_WINDOW_K bound ``pos_q - pos_k`` and
    ``pos_k - pos_q`` respectively; values <= 0 disable that bound.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    # K is addressed transposed ([BLOCK_DMODEL, BLOCK_N]) so tl.dot(q, k)
    # yields a [BLOCK_M, BLOCK_N] score tile.
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    # Lanes past the real head dim are padding of the BLOCK_DMODEL tile.
    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize running max (m_i), normalizer (l_i) and accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # 0 when this query block starts past the end of the sequence, which
    # makes the KV loop below run zero iterations for such programs.
    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    # Calculate the end position for attention computation
    end_n = cur_batch_seq_len

    # Causal pruning: keys beyond the last query row of this block are never
    # visible, so the KV loop can stop there.
    end_n = tl.minimum(end_n, (start_m + 1) * BLOCK_M) if IS_CAUSAL else end_n

    # The KV loop always starts at key 0; the sliding window does not prune
    # blocks here and is enforced purely through the mask inside the loop.
    start_n_limit = 0
    end_n_limit = block_mask * end_n

    for start_n in range(start_n_limit, end_n_limit, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )

        # Apply attention mask (causal + bidirectional sliding window)
        # Position indices in the sequence
        pos_q = offs_m[:, None]  # Query positions [BLOCK_M, 1]
        pos_k = start_n + offs_n[None, :]  # Key positions [1, BLOCK_N]

        # Valid sequence mask
        mask = pos_k < cur_batch_seq_len
        # Causal mask
        if IS_CAUSAL:
            mask &= pos_q >= pos_k

        # Bidirectional sliding window masks
        sliding_mask_q = (
            pos_q - pos_k <= SLIDING_WINDOW_Q if SLIDING_WINDOW_Q > 0 else None
        )
        sliding_mask_k = (
            pos_k - pos_q <= SLIDING_WINDOW_K if SLIDING_WINDOW_K > 0 else None
        )
        if sliding_mask_q is not None:
            mask &= sliding_mask_q
        if sliding_mask_k is not None:
            mask &= sliding_mask_k

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.where(mask, 0, float("-inf"))
        qk += tl.dot(q, k)
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_ij_valid_mask = m_ij > float("-inf")
        m_ij_masked = tl.where(m_ij_valid_mask, m_ij, 0.0)
        # -- compute p and l_ij --
        p = tl.exp(qk - m_ij_masked[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        m_i_new_mask = m_i_new > float("-inf")
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        # mask alpha and beta for sliding window
        alpha = tl.where(m_i_new_mask, alpha, 1.0)
        beta = tl.where(m_i_new_mask, beta, 0.0)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        # For sliding window there's a chance the l_i_new is 0 due to masking
        # the entire row. We need to set l_i_new 1 to avoid zero division.
        # m_i_new_mask is already boolean, so it is combined directly rather
        # than compared against -inf (a comparison that is always true).
        l_i_new_mask = (l_i_new != 0.0) & m_i_new_mask
        l_i_new_safe = tl.where(l_i_new_mask, l_i_new, 1.0)
        p_scale = beta / l_i_new_safe
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new_safe * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    # Rows past the sequence end and padded head-dim lanes are not written.
    tl.store(
        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
    )
|
||||
|
||||
|
||||
def get_block_size(dtype: torch.dtype) -> int:
    """Pick the tile size used for both BLOCK_M and BLOCK_N of the kernel.

    float32 inputs always get 32. Other dtypes get 128 on CUDA-alike
    platforms whose device capability major version is above 8, and 64
    everywhere else.
    """
    if dtype == torch.float32:
        return 32

    # The capability is only queried when the platform is CUDA-alike.
    use_large_tile = (
        current_platform.is_cuda_alike()
        and current_platform.get_device_capability().major > 8
    )
    return 128 if use_large_tile else 64
|
||||
|
||||
|
||||
def context_attention_fwd(
    q,
    k,
    v,
    o,
    b_start_loc,
    b_seq_len,
    max_input_len,
    is_causal=True,
    sliding_window_q=None,
    sliding_window_k=None,
):
    """Launch the Triton prefill attention kernel on packed sequences.

    Args:
        q, k, v: [b * s, head, head_dim]
        o: [b * s, head, head_dim]; the kernel writes the result here.
        b_start_loc: [b], first token offset of each sequence.
        b_seq_len: [b], length of each sequence.
        max_input_len: Upper bound on the sequence length; sizes the
            query-block axis of the launch grid.
        is_causal: Whether queries may only see keys at or before them.
        sliding_window_q: Bound on ``pos_q - pos_k``; ``None`` is passed to
            the kernel as 0, and the kernel ignores values <= 0.
        sliding_window_k: Bound on ``pos_k - pos_q``; same conventions.

    The softmax scale is ``1 / sqrt(head_dim of q)``. Nothing is returned.
    """
    block = get_block_size(q.dtype)

    head_dim_q = q.shape[-1]
    Lk = k.shape[-1]

    num_seqs = b_seq_len.shape[0]
    num_q_heads = q.shape[1]

    # One program per (sequence, query head, block of query rows).
    grid = (num_seqs, num_q_heads, triton.cdiv(max_input_len, block))

    window_q = 0 if sliding_window_q is None else sliding_window_q
    window_k = 0 if sliding_window_k is None else sliding_window_k

    _fwd_kernel[grid](
        q,
        k,
        v,
        1.0 / (head_dim_q**0.5),
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        # Number of query heads that share a single KV head (GQA).
        kv_group_num=num_q_heads // k.shape[1],
        BLOCK_M=block,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=block,
        IS_CAUSAL=is_causal,
        SLIDING_WINDOW_Q=window_q,
        SLIDING_WINDOW_K=window_k,
        num_warps=4 if Lk <= 64 else 8,
        num_stages=1,
        Lk=Lk,
    )
|
||||
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
@@ -289,14 +288,6 @@ class RocmPlatform(Platform):
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||
|
||||
# Priority 5: If model is Encoder-only self-attention type
|
||||
if (
|
||||
attn_selector_config.attn_type is not None
|
||||
and attn_selector_config.attn_type == AttentionType.ENCODER_ONLY
|
||||
):
|
||||
logger.info("Using FlexAttention backend.")
|
||||
return AttentionBackendEnum.FLEX_ATTENTION.get_path()
|
||||
|
||||
# Default: Triton Unified Attention
|
||||
logger.info("Using Triton Attention backend.")
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.attention.backends.abstract import (
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
@@ -309,6 +310,16 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Return whether this backend can serve ``attn_type``.

    Decoder, encoder, encoder-only and encoder-decoder attention are all
    accepted.
    """
    supported = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type in supported
|
||||
|
||||
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    """Accept every device capability; ``capability`` is not inspected."""
    return True
|
||||
@@ -341,6 +352,8 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
elif attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
|
||||
self.sliding_window = (sliding_window - 1, sliding_window - 1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
@@ -352,10 +365,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention is not implemented for TritonAttentionImpl"
|
||||
)
|
||||
self.attn_type = attn_type
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
@@ -417,6 +426,21 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
# Handle encoder attention differently - no KV cache needed
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return self._forward_encoder_attention(
|
||||
query[:num_actual_tokens],
|
||||
key[:num_actual_tokens],
|
||||
value[:num_actual_tokens],
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
layer,
|
||||
)
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if (
|
||||
@@ -495,3 +519,48 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_encoder_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """Run bidirectional encoder attention directly on Q/K/V (no KV cache).

    Args:
        query: shape = [num_encoder_tokens, num_heads, head_size]
        key: shape = [num_encoder_tokens, num_kv_heads, head_size]
        value: shape = [num_encoder_tokens, num_kv_heads, head_size]
        output: shape = [num_encoder_tokens, num_heads, head_size];
            written in place by the kernel.
        attn_metadata: Supplies ``query_start_loc``, ``seq_lens`` and
            ``max_query_len`` for the packed sequences.
        layer: The attention layer (not used by this method).

    Returns:
        The ``output`` tensor that was passed in.

    Raises:
        NotImplementedError: If the configured KV-cache dtype starts
            with ``"fp8"``.
    """
    # FP8 KV-cache dtypes are rejected on this path rather than handled.
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "quantization is not supported for encoder attention"
        )

    # NOTE(review): the kernel ignores window values <= 0, so a configured
    # sliding_window of 1 (stored as 0 here) looks like "no window" -- confirm
    # that this is intended.
    window_q, window_k = self.sliding_window

    # Triton prefill kernel applied straight to the Q, K, V tensors.
    context_attention_fwd(
        q=query,
        k=key,
        v=value,
        o=output,
        b_start_loc=attn_metadata.query_start_loc,
        b_seq_len=attn_metadata.seq_lens,
        max_input_len=attn_metadata.max_query_len,
        is_causal=False,  # Encoder attention is bidirectional
        sliding_window_q=window_q,
        sliding_window_k=window_k,
    )
    return output
|
||||
|
||||
Reference in New Issue
Block a user