diff --git a/tests/kernels/attention/test_triton_prefill_attention.py b/tests/kernels/attention/test_triton_prefill_attention.py
new file mode 100644
index 000000000..67c52cbfd
--- /dev/null
+++ b/tests/kernels/attention/test_triton_prefill_attention.py
@@ -0,0 +1,225 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
+
+
+def ref_masked_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    is_causal: bool = True,
+    sliding_window_q: int | None = None,
+    sliding_window_k: int | None = None,
+) -> torch.Tensor:
+    """Reference implementation using PyTorch SDPA."""
+    # q, k, v: [total_tokens, num_heads, head_dim]
+    # SDPA expects [batch, num_heads, seq_len, head_dim]
+
+    total_tokens = q.shape[0]
+
+    # Add batch dimension and transpose
+    q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, total_tokens, head_dim]
+    k = k.unsqueeze(0).transpose(1, 2)  # [1, num_heads, total_tokens, head_dim]
+    v = v.unsqueeze(0).transpose(1, 2)  # [1, num_heads, total_tokens, head_dim]
+
+    # Create attention mask if needed
+    attn_mask = None
+    use_causal = is_causal
+
+    # If we have a sliding window or need custom masking, create an explicit mask
+    sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
+    sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
+    if (sliding_window_q > 0) or (sliding_window_k > 0):
+        # Position indices
+        pos_q = torch.arange(total_tokens, device=q.device).unsqueeze(1)
+        pos_k = torch.arange(total_tokens, device=q.device).unsqueeze(0)
+
+        # Start with an all-True mask (True = attend, False = masked out)
+        mask = torch.ones(
+            (total_tokens, total_tokens), dtype=torch.bool, device=q.device
+        )
+
+        # Apply causal mask
+        if is_causal:
+            mask = mask & (pos_q >= pos_k)
+
+        # Apply sliding window masks
+        sliding_window_mask = torch.ones_like(mask)
+        if sliding_window_q > 0:
+            sliding_window_mask &= pos_q - pos_k <= sliding_window_q
+
+        if sliding_window_k > 0:
+            sliding_window_mask &= pos_k - pos_q <= sliding_window_k
+
+        mask = mask & sliding_window_mask
+
+        attn_mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
+        use_causal = False  # Don't use is_causal when providing an explicit mask
+
+    # Use SDPA
+    output = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, is_causal=use_causal, dropout_p=0.0
+    )
+
+    # Convert back to original shape: [total_tokens, num_heads, head_dim]
+    output = output.transpose(1, 2).squeeze(0)
+
+    return output
+
+
+@pytest.mark.parametrize("B", [5])
+@pytest.mark.parametrize("max_seq_len", [1024])
+@pytest.mark.parametrize("H_Q", [32])
+@pytest.mark.parametrize("H_KV", [32, 8])
+@pytest.mark.parametrize("D", [128])
+@pytest.mark.parametrize("is_causal", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_context_attention(
+    B: int,
+    max_seq_len: int,
+    H_Q: int,
+    H_KV: int,
+    D: int,
+    is_causal: bool,
+    dtype: torch.dtype,
+):
+    """Test basic context attention without sliding window."""
+    torch.manual_seed(42)
+
+    # Generate a random sequence length for each sequence in the batch
+    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
+    total_tokens = seq_lens.sum().item()
+
+    # Create batch start locations
+    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
+    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    # Create input tensors
+    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
+    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+    o = torch.zeros_like(q)
+
+    # Call Triton kernel
+    context_attention_fwd(
+        q,
+        k,
+        v,
+        o,
+        b_start_loc,
+        seq_lens,
+        max_seq_len,
+        is_causal=is_causal,
+        sliding_window_q=None,
+        sliding_window_k=None,
+    )
+
+    # Compute reference output for each sequence in the batch
+    o_ref = torch.zeros_like(q)
+    for i in range(B):
+        start = b_start_loc[i].item()
+        end = start + seq_lens[i].item()
+
+        q_seq = q[start:end]
+        k_seq = k[start:end]
+        v_seq = v[start:end]
+
+        # Expand KV heads if using GQA
+        if H_Q != H_KV:
+            kv_group_num = H_Q // H_KV
+            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
+            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)
+
+        o_ref[start:end] = ref_masked_attention(
+            q_seq,
+            k_seq,
+            v_seq,
+            is_causal=is_causal,
+            sliding_window_q=None,
+            sliding_window_k=None,
+        )
+
+    # Compare outputs
+    torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("B", [4])
+@pytest.mark.parametrize("max_seq_len", [1024])
+@pytest.mark.parametrize("H_Q", [32])
+@pytest.mark.parametrize("H_KV", [32, 8])
+@pytest.mark.parametrize("D", [128])
+@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_context_attention_sliding_window(
+    B: int,
+    max_seq_len: int,
+    H_Q: int,
+    H_KV: int,
+    D: int,
+    sliding_window: tuple[int, int],
+    dtype: torch.dtype,
+):
+    """Test context attention with sliding window."""
+    sliding_window_q, sliding_window_k = sliding_window
+    torch.manual_seed(42)
+
+    # Generate a random sequence length for each sequence in the batch
+    seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
+    total_tokens = seq_lens.sum().item()
+
+    # Create batch start locations
+    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
+    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    # Create input tensors
+    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
+    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+    o = torch.zeros_like(q)
+
+    # Call Triton kernel
+    context_attention_fwd(
+        q,
+        k,
+        v,
+        o,
+        b_start_loc,
+        seq_lens,
+        max_seq_len,
+        is_causal=False,
+        sliding_window_q=sliding_window_q,
+        sliding_window_k=sliding_window_k,
+    )
+
+    # Compute reference output for each sequence in the batch
+    o_ref = torch.zeros_like(q)
+    for i in range(B):
+        start = b_start_loc[i].item()
+        end = start + seq_lens[i].item()
+
+        q_seq = q[start:end]
+        k_seq = k[start:end]
+        v_seq = v[start:end]
+
+        # Expand KV heads if using GQA
+        if H_Q != H_KV:
+            kv_group_num = H_Q // H_KV
+            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
+            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)
+
+        o_ref[start:end] = ref_masked_attention(
+            q_seq,
+            k_seq,
+            v_seq,
+            is_causal=False,
+            sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
+            sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
+        )
+
+    # Compare outputs
+    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py
index b206995a9..23459963f 100644
--- a/tests/models/multimodal/generation/test_whisper.py
+++ b/tests/models/multimodal/generation/test_whisper.py
@@ -114,7 +114,7 @@ def check_model_available(model: str) -> None:
 @pytest.mark.core_model
 @pytest.mark.cpu_model
 @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["half", "float"])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @create_new_process_for_each_test("spawn")
diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 6e08b9316..5495b4fc1 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
     create_vllm_config,
     try_get_attention_backend,
 )
+from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import ModelConfig
 from vllm.platforms import current_platform
@@ -83,6 +84,13 @@ BATCH_SPECS = {
     ),
     "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
     "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
+    # encoder-only
+    "small_encoder_prefill": BatchSpec(
+        seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
+    ),
+    "medium_encoder_prefill": BatchSpec(
+        seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
+    ),
 }
 
 
@@ -209,6 +217,7 @@ def run_attention_backend(
     key: torch.Tensor,
     value: torch.Tensor,
     kv_cache: torch.Tensor,
+    attn_type: AttentionType = AttentionType.DECODER,
     sliding_window: int | None = None,
 ) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
@@ -276,6 +285,7 @@
         num_kv_heads=num_kv_heads,
         alibi_slopes=None,
         sliding_window=sliding_window,
+        attn_type=attn_type,
         kv_cache_dtype="auto",
     )
 
@@ -299,6 +309,7 @@ def _test_backend_correctness(
     backend_to_test: list[AttentionBackendEnum | str],
     mask_mod,
     *,
+    attn_type: AttentionType = AttentionType.DECODER,
     block_size: int = 16,
     atol: float = 1e-2,
     rtol: float = 1e-2,
@@ -436,6 +447,9 @@ def _test_backend_correctness(
     common_attn_metadata = create_common_attn_metadata(
         batch_spec, vllm_config.cache_config.block_size, device
     )
+    if attn_type == AttentionType.ENCODER_ONLY:
+        # Encoder-only attention is bidirectional, so disable the causal mask
+        common_attn_metadata.causal = False
 
     # 3. Simulate Paged KV Cache and a realistic slot_mapping
     kv_cache = create_and_prepopulate_kv_cache(
@@ -491,6 +505,7 @@
                 value_vllm,
                 kv_cache_for_backend,
                 sliding_window=sliding_window,
+                attn_type=attn_type,
             )
         finally:
             if reset_kv_cache_layout:
@@ -676,3 +691,45 @@ def test_sliding_window_backend_correctness(
         block_size=128,
         tensor_parallel_size=tensor_parallel_size,
     )
+
+
+@pytest.mark.parametrize(
+    "batch_spec_name",
+    [
+        "small_encoder_prefill",
+        "medium_encoder_prefill",
+    ],
+)
+@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
+def test_sliding_window_encoder_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
+    """Test backend correctness for encoder-only (bidirectional) sliding
+    window attention."""
+
+    def bidi_sliding_window_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+        sliding_window: int,
+    ):
+        return torch.abs(q_idx + context_len - kv_idx) < sliding_window
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
+    sliding_window = model_config.get_sliding_window()
+    sliding_window_mask_mod_fn = partial(
+        bidi_sliding_window_mask_mod, sliding_window=sliding_window
+    )
+
+    _test_backend_correctness(
+        batch_spec,
+        model,
+        SLIDING_WINDOW_BACKENDS_TO_TEST,
+        sliding_window_mask_mod_fn,
+        attn_type=AttentionType.ENCODER_ONLY,
+        tensor_parallel_size=tensor_parallel_size,
+    )
diff --git a/vllm/attention/ops/triton_prefill_attention.py b/vllm/attention/ops/triton_prefill_attention.py
new file mode 100644
index 000000000..ae7332830
--- /dev/null
+++ b/vllm/attention/ops/triton_prefill_attention.py
@@ -0,0 +1,271 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/97cb762bb65ebf05025eb342de03c184660427a3/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
+# Changes:
+# - Add support for sliding window attention
+
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Memory-efficient attention for prefill.
+It supports page size = 1.
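+
+A minimal usage sketch (illustrative only; assumes a CUDA device, with shapes
+as described in the ``context_attention_fwd`` docstring below):
+
+    import torch
+
+    from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
+
+    # Two packed sequences of lengths 3 and 5; 8 query heads, 8 KV heads,
+    # head_dim 64, so q/k/v are [total_tokens=8, heads, head_dim].
+    q = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
+    k = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
+    v = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
+    o = torch.empty_like(q)
+    b_start_loc = torch.tensor([0, 3], dtype=torch.int32, device="cuda")
+    b_seq_len = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
+    context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=5)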
+""" + +# Adapted from +# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + SLIDING_WINDOW_Q: tl.constexpr, + SLIDING_WINDOW_K: tl.constexpr, + Lk: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] + + mask_d = offs_d < Lk + + q = tl.load( + Q + off_q, + mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + # Calculate the end position for attention computation + end_n = cur_batch_seq_len + + # Apply causal attention pruning and sliding window attention pruning + end_n = tl.minimum(end_n, (start_m + 1) * BLOCK_M) if IS_CAUSAL else end_n + + # Calculate the start position for backward sliding window + start_n_limit = 0 + end_n_limit = block_mask * end_n + + for start_n in range(start_n_limit, end_n_limit, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), + other=0.0, + ) + + # Apply attention mask (causal + bidirectional sliding window) + # Position indices in the sequence + pos_q = offs_m[:, None] # Query positions [BLOCK_M, 1] + pos_k = start_n + offs_n[None, :] # Key positions [1, BLOCK_N] + + # Valid sequence mask + mask = pos_k < cur_batch_seq_len + # Causal mask + if IS_CAUSAL: + mask &= pos_q >= pos_k + + # Bidirectional sliding window masks + sliding_mask_q = ( + pos_q - pos_k <= SLIDING_WINDOW_Q if SLIDING_WINDOW_Q > 0 else None + ) + sliding_mask_k = ( + pos_k - pos_q <= SLIDING_WINDOW_K if SLIDING_WINDOW_K > 0 else None + ) + if sliding_mask_q is not None: + mask &= sliding_mask_q + if sliding_mask_k is not None: + mask &= sliding_mask_k + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.where(mask, 0, float("-inf")) + qk += tl.dot(q, k) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_ij_valid_mask = m_ij > float("-inf") + m_ij_masked = tl.where(m_ij_valid_mask, m_ij, 0.0) + # -- compute p and l_ij -- + p = tl.exp(qk - m_ij_masked[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + m_i_new_mask = m_i_new > float("-inf") + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + # mask alpha and beta for sliding window + alpha = tl.where(m_i_new_mask, alpha, 1.0) + beta = tl.where(m_i_new_mask, beta, 0.0) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + # For sliding window there's a chance the l_i_new is 0 due to masking + # the entire row. We need to set l_i_new 1 to avoid zero division + l_i_new_mask = (l_i_new != 0.0) & (m_i_new_mask > float("-inf")) + l_i_new_safe = tl.where(l_i_new_mask, l_i_new, 1.0) + p_scale = beta / l_i_new_safe + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new_safe * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] + ) + out_ptrs = Out + off_o + tl.store( + out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]) + ) + + +def get_block_size(dtype: torch.dtype) -> int: + if dtype == torch.float32: + return 32 + elif ( + current_platform.is_cuda_alike() + ) and current_platform.get_device_capability().major > 8: + return 128 + else: + return 64 + + +def context_attention_fwd( + q, + k, + v, + o, + b_start_loc, + b_seq_len, + max_input_len, + is_causal=True, + sliding_window_q=None, + sliding_window_k=None, +): + """ + q, k, v: [b * s, head, head_dim] + b_start_loc: [b] + b_seq_len: [b] + out: [b * s, head, head_dim] + """ + BLOCK = get_block_size(q.dtype) + + Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + num_warps = 4 if Lk <= 64 else 8 + + sliding_window_q = sliding_window_q if sliding_window_q is not None else 0 + sliding_window_k = sliding_window_k if sliding_window_k is not None else 0 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=triton.next_power_of_2(Lk), + BLOCK_N=BLOCK, + IS_CAUSAL=is_causal, + SLIDING_WINDOW_Q=sliding_window_q, + SLIDING_WINDOW_K=sliding_window_k, + num_warps=num_warps, + num_stages=1, + Lk=Lk, + ) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 785e457fc..278be5a71 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Optional import torch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.utils.torch_utils import cuda_device_count_stateless @@ -289,14 +288,6 @@ class RocmPlatform(Platform): 
logger.info("Using Aiter Flash Attention backend.") return AttentionBackendEnum.ROCM_AITER_FA.get_path() - # Priority 5: If model is Encoder-only self-attention type - if ( - attn_selector_config.attn_type is not None - and attn_selector_config.attn_type == AttentionType.ENCODER_ONLY - ): - logger.info("Using FlexAttention backend.") - return AttentionBackendEnum.FLEX_ATTENTION.get_path() - # Default: Triton Unified Attention logger.info("Using Triton Attention backend.") return AttentionBackendEnum.TRITON_ATTN.get_path() diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index ca7be990c..8cf363d59 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -13,6 +13,7 @@ from vllm.attention.backends.abstract import ( AttentionType, MultipleOf, ) +from vllm.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash, ) @@ -309,6 +310,16 @@ class TritonAttentionBackend(AttentionBackend): def supports_sink(cls) -> bool: return True + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """TritonAttention supports all attention types.""" + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) + @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: return True @@ -341,6 +352,8 @@ class TritonAttentionImpl(AttentionImpl): self.alibi_slopes = alibi_slopes if sliding_window is None: self.sliding_window = (-1, -1) + elif attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): + self.sliding_window = (sliding_window - 1, sliding_window - 1) else: self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype @@ -352,10 +365,6 @@ class TritonAttentionImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: - raise NotImplementedError( - "Encoder self-attention is not implemented for TritonAttentionImpl" - ) self.attn_type = attn_type self.fp8_dtype = current_platform.fp8_dtype() @@ -417,6 +426,21 @@ class TritonAttentionImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) + + # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(1) if ( @@ -495,3 +519,48 @@ class TritonAttentionImpl(AttentionImpl): ) return output + + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + layer: torch.nn.Module, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. 
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # FP8 KV-cache quantization is not yet supported for encoder attention
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "FP8 quantization is not supported for encoder attention"
+            )
+
+        # Use encoder-specific metadata for sequence information
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        max_query_len = attn_metadata.max_query_len
+
+        # Run the Triton prefill attention kernel directly on Q, K, V
+        context_attention_fwd(
+            q=query,
+            k=key,
+            v=value,
+            o=output,
+            b_start_loc=query_start_loc,
+            b_seq_len=seq_lens,
+            max_input_len=max_query_len,
+            is_causal=False,  # Encoder attention is bidirectional
+            sliding_window_q=self.sliding_window[0],
+            sliding_window_k=self.sliding_window[1],
+        )
+        return output
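
Note on the sliding-window convention in TritonAttentionImpl above: a model
window of size W is stored as the pair of inclusive offsets (W - 1, W - 1)
for bidirectional encoder attention and (W - 1, 0) for causal decoder
attention; for the encoder path these offsets are forwarded to
context_attention_fwd as sliding_window_q / sliding_window_k. A minimal
PyTorch sketch of the boolean mask the two offsets imply (illustrative only,
mirroring ref_masked_attention in the new test; not part of the patch):

    import torch

    def sliding_window_mask(
        seq_len: int, window_q: int, window_k: int
    ) -> torch.Tensor:
        # True where attention is allowed.
        pos_q = torch.arange(seq_len).unsqueeze(1)  # [seq_len, 1]
        pos_k = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
        if window_q > 0:
            # Limit how far behind its own position each query may look.
            mask &= pos_q - pos_k <= window_q
        if window_k > 0:
            # Limit how far ahead of its own position each query may look.
            mask &= pos_k - pos_q <= window_k
        return mask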