[v1] Add encoder-only/cross attention support to Triton Attention backend (#31406)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-06 00:00:23 +08:00
committed by GitHub
parent 911d38ed99
commit 6aa5b18e1d
6 changed files with 627 additions and 14 deletions

View File

@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
def ref_masked_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
is_causal: bool = True,
sliding_window_q: int | None = None,
sliding_window_k: int | None = None,
) -> torch.Tensor:
"""Reference implementation using PyTorch SDPA."""
# q, k, v: [total_tokens, num_heads, head_dim]
# SDPA expects [batch, num_heads, seq_len, head_dim]
total_tokens = q.shape[0]
# Add batch dimension and transpose
q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
k = k.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
v = v.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
# Create attention mask if needed
attn_mask = None
use_causal = is_causal
# If we have sliding window or need custom masking, create explicit mask
sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
if (sliding_window_q > 0) or (sliding_window_k > 0):
# Position indices
pos_q = torch.arange(total_tokens, device=q.device).unsqueeze(1)
pos_k = torch.arange(total_tokens, device=q.device).unsqueeze(0)
# Start with valid mask (False = no masking)
mask = torch.ones(
(total_tokens, total_tokens), dtype=torch.bool, device=q.device
)
# Apply causal mask
if is_causal:
mask = mask & (pos_q >= pos_k)
# Apply sliding window masks
sliding_window_mask = torch.ones_like(mask)
if sliding_window_q > 0:
sliding_window_mask &= pos_q - pos_k <= sliding_window_q
if sliding_window_k > 0:
sliding_window_mask &= pos_k - pos_q <= sliding_window_k
mask = mask & sliding_window_mask
attn_mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
use_causal = False # Don't use is_causal when providing explicit mask
# Use SDPA
output = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=use_causal, dropout_p=0.0
)
# Convert back to original shape: [total_tokens, num_heads, head_dim]
output = output.transpose(1, 2).squeeze(0)
return output
@pytest.mark.parametrize("B", [5])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    is_causal: bool,
    dtype: torch.dtype,
):
    """Compare the Triton context attention kernel against the SDPA
    reference, without sliding window."""
    torch.manual_seed(42)

    # Random per-sequence lengths in [max_seq_len // 2, max_seq_len].
    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
    total_tokens = seq_lens.sum().item()

    # Exclusive prefix sum of lengths gives each sequence's start offset.
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Packed [total_tokens, heads, head_dim] inputs.
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)

    # Run the Triton kernel under test.
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=is_causal,
        sliding_window_q=None,
        sliding_window_k=None,
    )

    # Reference: run SDPA per sequence, expanding KV heads for GQA.
    o_ref = torch.zeros_like(q)
    group = H_Q // H_KV
    for start, length in zip(b_start_loc.tolist(), seq_lens.tolist()):
        end = start + length
        q_i, k_i, v_i = q[start:end], k[start:end], v[start:end]
        if group > 1:
            k_i = k_i.repeat_interleave(group, dim=1)
            v_i = v_i.repeat_interleave(group, dim=1)
        o_ref[start:end] = ref_masked_attention(
            q_i,
            k_i,
            v_i,
            is_causal=is_causal,
            sliding_window_q=None,
            sliding_window_k=None,
        )

    torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention_sliding_window(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    sliding_window: tuple[int, int],
    dtype: torch.dtype,
):
    """Test context attention with bidirectional sliding window.

    Compares the Triton kernel against the SDPA reference for each
    (backward, forward) window pair; a 0 entry disables that side.
    """
    # Fix: the docstring previously appeared *after* this statement, making
    # it a dead string expression rather than the function's docstring.
    sliding_window_q, sliding_window_k = sliding_window
    torch.manual_seed(42)
    # Generate random sequence lengths for each batch
    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
    total_tokens = seq_lens.sum().item()
    # Create batch start locations (exclusive prefix sum of lengths)
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
    # Create packed [total_tokens, heads, head_dim] input tensors
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)
    # Call Triton kernel (bidirectional: is_causal=False)
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=False,
        sliding_window_q=sliding_window_q,
        sliding_window_k=sliding_window_k,
    )
    # Compute reference output for each sequence in batch
    o_ref = torch.zeros_like(q)
    for i in range(B):
        start = b_start_loc[i].item()
        end = start + seq_lens[i].item()
        q_seq = q[start:end]
        k_seq = k[start:end]
        v_seq = v[start:end]
        # Expand KV heads if using GQA
        if H_Q != H_KV:
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)
        # The reference treats 0 as "no window", so map 0 -> None.
        o_ref[start:end] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=False,
            sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
            sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
        )
    # Compare outputs (looser tolerance than the non-windowed test)
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)

View File

@@ -114,7 +114,7 @@ def check_model_available(model: str) -> None:
@pytest.mark.core_model
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("enforce_eager", [True, False])
@create_new_process_for_each_test("spawn")

View File

@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig
from vllm.platforms import current_platform
@@ -83,6 +84,13 @@ BATCH_SPECS = {
),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
# encoder-only
"small_encoder_prefill": BatchSpec(
seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
),
"medium_encoder_prefill": BatchSpec(
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
),
}
@@ -209,6 +217,7 @@ def run_attention_backend(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: AttentionType = AttentionType.DECODER,
sliding_window: int | None = None,
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
@@ -276,6 +285,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=sliding_window,
attn_type=attn_type,
kv_cache_dtype="auto",
)
@@ -299,6 +309,7 @@ def _test_backend_correctness(
backend_to_test: list[AttentionBackendEnum | str],
mask_mod,
*,
attn_type: AttentionType = AttentionType.DECODER,
block_size: int = 16,
atol: float = 1e-2,
rtol: float = 1e-2,
@@ -436,6 +447,9 @@ def _test_backend_correctness(
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
if attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only, all tokens are prefill tokens
common_attn_metadata.causal = False
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
@@ -491,6 +505,7 @@ def _test_backend_correctness(
value_vllm,
kv_cache_for_backend,
sliding_window=sliding_window,
attn_type=attn_type,
)
finally:
if reset_kv_cache_layout:
@@ -676,3 +691,45 @@ def test_sliding_window_backend_correctness(
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_encoder_prefill",
        "medium_encoder_prefill",
    ],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backend correctness for bidirectional (encoder-only) sliding
    window attention."""

    def bidi_sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # Symmetric window: a query attends to keys within `sliding_window`
        # positions on either side; no causal constraint for encoders.
        return torch.abs(q_idx + context_len - kv_idx) < sliding_window

    batch_spec = BATCH_SPECS[batch_spec_name]
    model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
    # Window size comes from the model's own config (embeddinggemma uses SWA).
    sliding_window = model_config.get_sliding_window()
    sliding_window_mask_mod_fn = partial(
        bidi_sliding_window_mask_mod, sliding_window=sliding_window
    )
    _test_backend_correctness(
        batch_spec,
        model,
        SLIDING_WINDOW_BACKENDS_TO_TEST,
        sliding_window_mask_mod_fn,
        attn_type=AttentionType.ENCODER_ONLY,
        tensor_parallel_size=tensor_parallel_size,
    )

View File

@@ -0,0 +1,271 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/sgl-project/sglang/blob/97cb762bb65ebf05025eb342de03c184660427a3/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# Changes:
# - Add support for sliding window attention
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    SLIDING_WINDOW_Q: tl.constexpr,
    SLIDING_WINDOW_K: tl.constexpr,
    Lk: tl.constexpr,
):
    """Flash-attention-style prefill kernel over packed varlen batches.

    Grid: (batch, query head, BLOCK_M query tile). Each program streams
    over its sequence's key/value blocks with an online softmax, applying
    optional causal and bidirectional sliding-window masking
    (SLIDING_WINDOW_Q bounds lookback, SLIDING_WINDOW_K bounds lookahead;
    0 disables that side).
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)
    # GQA: kv_group_num query heads share one KV head.
    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    # BLOCK_DMODEL is Lk padded to a power of two; mask_d hides the padding.
    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # Online-softmax running state: row max (m_i), row sum (l_i), output acc.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # 0 when this whole query tile lies past the sequence end (loop skipped).
    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    # Calculate the end position for attention computation
    end_n = cur_batch_seq_len
    # Apply causal attention pruning and sliding window attention pruning
    end_n = tl.minimum(end_n, (start_m + 1) * BLOCK_M) if IS_CAUSAL else end_n

    # Calculate the start position for backward sliding window
    start_n_limit = 0
    end_n_limit = block_mask * end_n
    for start_n in range(start_n_limit, end_n_limit, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # Apply attention mask (causal + bidirectional sliding window)
        # Position indices in the sequence
        pos_q = offs_m[:, None]  # Query positions [BLOCK_M, 1]
        pos_k = start_n + offs_n[None, :]  # Key positions [1, BLOCK_N]
        # Valid sequence mask
        mask = pos_k < cur_batch_seq_len
        # Causal mask
        if IS_CAUSAL:
            mask &= pos_q >= pos_k
        # Bidirectional sliding window masks (constexpr 0 disables a side)
        sliding_mask_q = (
            pos_q - pos_k <= SLIDING_WINDOW_Q if SLIDING_WINDOW_Q > 0 else None
        )
        sliding_mask_k = (
            pos_k - pos_q <= SLIDING_WINDOW_K if SLIDING_WINDOW_K > 0 else None
        )
        if sliding_mask_q is not None:
            mask &= sliding_mask_q
        if sliding_mask_k is not None:
            mask &= sliding_mask_k

        # Masked positions get -inf before softmax; scale applied after dot.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.where(mask, 0, float("-inf"))
        qk += tl.dot(q, k)
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_ij_valid_mask = m_ij > float("-inf")
        m_ij_masked = tl.where(m_ij_valid_mask, m_ij, 0.0)

        # -- compute p and l_ij --
        p = tl.exp(qk - m_ij_masked[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        m_i_new_mask = m_i_new > float("-inf")
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        # mask alpha and beta for sliding window
        alpha = tl.where(m_i_new_mask, alpha, 1.0)
        beta = tl.where(m_i_new_mask, beta, 0.0)
        l_i_new = alpha * l_i + beta * l_ij

        # -- update output accumulator --
        # scale p
        # For sliding window there's a chance the l_i_new is 0 due to masking
        # the entire row. We need to set l_i_new 1 to avoid zero division
        # NOTE(review): m_i_new_mask is already boolean, so the `> -inf`
        # comparison below is always true and the expression reduces to
        # `l_i_new != 0.0` — confirm whether `m_i_new > float("-inf")`
        # (without the `_mask`) was intended.
        l_i_new_mask = (l_i_new != 0.0) & (m_i_new_mask > float("-inf"))
        l_i_new_safe = tl.where(l_i_new_mask, l_i_new, 1.0)
        p_scale = beta / l_i_new_safe
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new_safe * alpha
        acc = acc * acc_scale[:, None]

        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
    )
def get_block_size(dtype: torch.dtype) -> int:
    """Pick the Triton tile size (BLOCK_M/BLOCK_N) for the given dtype."""
    if dtype == torch.float32:
        # fp32 tiles cost twice the registers/shared memory of 16-bit
        # dtypes, so keep them small.
        return 32
    hopper_or_newer = (
        current_platform.is_cuda_alike()
        and current_platform.get_device_capability().major > 8
    )
    return 128 if hopper_or_newer else 64
def context_attention_fwd(
    q,
    k,
    v,
    o,
    b_start_loc,
    b_seq_len,
    max_input_len,
    is_causal=True,
    sliding_window_q=None,
    sliding_window_k=None,
):
    """Launch the prefill attention kernel on a packed varlen batch.

    q, k, v, o: [total_tokens, num_heads, head_dim]
    b_start_loc: [batch] start offset of each sequence in the packed dim
    b_seq_len: [batch] length of each sequence
    max_input_len: upper bound on sequence length (sizes the launch grid)
    """
    tile = get_block_size(q.dtype)
    head_dim_q = q.shape[-1]
    head_dim_k = k.shape[-1]
    # Standard 1/sqrt(d) attention scaling.
    softmax_scale = 1.0 / (head_dim_q**0.5)
    num_seqs = b_seq_len.shape[0]
    num_heads = q.shape[1]
    # GQA: number of query heads sharing each KV head.
    group_size = q.shape[1] // k.shape[1]

    # One program per (sequence, head, query tile).
    launch_grid = (num_seqs, num_heads, triton.cdiv(max_input_len, tile))
    warps = 4 if head_dim_k <= 64 else 8

    # The kernel takes 0 (constexpr) to mean "window disabled".
    win_q = sliding_window_q if sliding_window_q is not None else 0
    win_k = sliding_window_k if sliding_window_k is not None else 0

    _fwd_kernel[launch_grid](
        q,
        k,
        v,
        softmax_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=group_size,
        BLOCK_M=tile,
        BLOCK_DMODEL=triton.next_power_of_2(head_dim_k),
        BLOCK_N=tile,
        IS_CAUSAL=is_causal,
        SLIDING_WINDOW_Q=win_q,
        SLIDING_WINDOW_K=win_k,
        num_warps=warps,
        num_stages=1,
        Lk=head_dim_k,
    )

View File

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -289,14 +288,6 @@ class RocmPlatform(Platform):
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
# Priority 5: If model is Encoder-only self-attention type
if (
attn_selector_config.attn_type is not None
and attn_selector_config.attn_type == AttentionType.ENCODER_ONLY
):
logger.info("Using FlexAttention backend.")
return AttentionBackendEnum.FLEX_ATTENTION.get_path()
# Default: Triton Unified Attention
logger.info("Using Triton Attention backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()

View File

@@ -13,6 +13,7 @@ from vllm.attention.backends.abstract import (
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
@@ -309,6 +310,16 @@ class TritonAttentionBackend(AttentionBackend):
def supports_sink(cls) -> bool:
return True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """TritonAttention handles every attention type: decoder, encoder,
    encoder-only, and encoder-decoder (cross) attention."""
    supported = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type in supported
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
@@ -341,6 +352,8 @@ class TritonAttentionImpl(AttentionImpl):
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
@@ -352,10 +365,6 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for TritonAttentionImpl"
)
self.attn_type = attn_type
self.fp8_dtype = current_platform.fp8_dtype()
@@ -417,6 +426,21 @@ class TritonAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
if (
@@ -495,3 +519,48 @@ class TritonAttentionImpl(AttentionImpl):
)
return output
def _forward_encoder_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """Forward pass for encoder (bidirectional) attention without KV cache.

    Runs the Triton prefill attention kernel directly on the packed
    Q/K/V tensors; no paged KV cache is read or written on this path.

    Args:
        query: shape = [num_encoder_tokens, num_heads, head_size]
        key: shape = [num_encoder_tokens, num_kv_heads, head_size]
        value: shape = [num_encoder_tokens, num_kv_heads, head_size]
        output: shape = [num_encoder_tokens, num_heads, head_size];
            written in place and also returned.
        attn_metadata: Encoder attention metadata
        layer: The attention layer (not used on this path)

    Raises:
        NotImplementedError: if an fp8 KV-cache dtype is configured;
            quantization is unsupported for cache-free encoder attention.
    """
    # For encoder attention, process FP8 quantization if needed
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "quantization is not supported for encoder attention"
        )

    # Use encoder-specific metadata for sequence information
    query_start_loc = attn_metadata.query_start_loc
    seq_lens = attn_metadata.seq_lens
    max_query_len = attn_metadata.max_query_len

    # Run the Triton prefill attention kernel directly on Q/K/V.
    # (self.sliding_window was set symmetrically for encoder attn types.)
    context_attention_fwd(
        q=query,
        k=key,
        v=value,
        o=output,
        b_start_loc=query_start_loc,
        b_seq_len=seq_lens,
        max_input_len=max_query_len,
        is_causal=False,  # Encoder attention is bidirectional
        sliding_window_q=self.sliding_window[0],
        sliding_window_k=self.sliding_window[1],
    )
    return output