[v1] Add PrefixLM support to TritonAttention backend (#30386)

Authored by Isotr0py on 2025-12-18 08:05:24 +08:00; committed by GitHub
parent 05a83dc6ee
commit 74a1ac38b0
4 changed files with 280 additions and 123 deletions

@@ -19,7 +19,6 @@ from collections.abc import Iterable
from itertools import islice
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
@@ -226,77 +225,9 @@ class Gemma3Attention(nn.Module):
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        if not kwargs.get("has_images", False):
            # Fast path for text-only inputs. The performance for text-only
            # inputs is not affected by the naive attention below.
            output, _ = self.o_proj(attn_output)
            return output

        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
        # that correspond to the same image while using causal attention
        # otherwise. Current attention backends cannot handle this pattern, so
        # we temporarily use a naive attention implementation with mask tensors.
        # We intentionally keep the attention backend as-is and only override
        # `attn_output` with the naive implementation's output. This minimizes
        # changes to existing model runners and attention backends. The call to
        # `self.attn(q, k, v)` is only used to populate the KV cache - its
        # output is discarded and overwritten below. While this duplicates
        # computation, it maintains compatibility.
        # TODO(woosuk): Optimize by implementing custom attention kernels.
        attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
        output, _ = self.o_proj(attn_output)
        return output

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # NOTE(woosuk): As described in the comment above, this code is not
        # meant to be performant. It is only meant to be correct.
        q = q.view(-1, self.num_heads, self.head_dim)

        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            # `scale` must be passed by keyword; the fifth positional argument
            # of F.scaled_dot_product_attention is dropout_p.
            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                scale=self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out

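# --- Illustrative sketch (editor's note; not part of this commit) ---
# The `global_attn_masks` / `local_attn_masks` consumed by naive_attn_with_masks
# are built by the model runner, which this hunk does not show. Assuming each
# token carries a per-image index (-1 for text tokens), boolean masks with the
# semantics the NOTE above describes - bidirectional between tokens of the same
# image, causal otherwise, with an extra sliding-window restriction for local
# layers - could be built roughly as follows. The function name, the
# `token_types` encoding, and the mask shapes are assumptions, not vLLM's
# actual implementation.
import torch

def sketch_build_gemma3_attn_masks(
    token_types: torch.Tensor,  # (seq_len,), image index per token, -1 for text
    sliding_window: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    seq_len = token_types.shape[0]
    idx = torch.arange(seq_len, device=token_types.device)
    # Causal rule: query i may attend to keys j <= i.
    causal = idx[None, :] <= idx[:, None]
    # Bidirectional rule: tokens belonging to the same image attend to each other.
    same_image = (token_types[:, None] == token_types[None, :]) & (
        token_types[:, None] >= 0
    )
    global_allowed = causal | same_image
    # Local (sliding-window) layers additionally limit causal attention to the
    # most recent `sliding_window` positions; same-image pairs stay visible.
    in_window = (idx[:, None] - idx[None, :]) < sliding_window
    local_allowed = global_allowed & (in_window | same_image)
    # Boolean masks broadcastable to (1, num_heads, seq_len, seq_len); True means
    # "may attend", matching F.scaled_dot_product_attention's attn_mask semantics.
    return global_allowed[None, None], local_allowed[None, None]
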
class Gemma3DecoderLayer(nn.Module):
    def __init__(