[v1] Add PrefixLM support to TritonAttention backend (#30386)
@@ -19,7 +19,6 @@ from collections.abc import Iterable
from itertools import islice

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
@@ -226,77 +225,9 @@ class Gemma3Attention(nn.Module):

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        if not kwargs.get("has_images", False):
            # Fast path for text-only inputs. The performance for text-only
            # inputs is not affected by the naive attention below.
            output, _ = self.o_proj(attn_output)
            return output
        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
        # that correspond to the same image, while using causal attention
        # otherwise. Current attention backends cannot handle this pattern, so
        # we temporarily use a naive attention implementation with mask tensors.

        # We intentionally keep the attention backend as-is and only override
        # `attn_output` with the naive implementation's output. This minimizes
        # changes to existing model runners and attention backends. The call to
        # `self.attn(q, k, v)` is only used to populate the KV cache - its
        # output is discarded and overwritten below. While this duplicates
        # computation, it maintains compatibility.
        # TODO(woosuk): Optimize by implementing custom attention kernels.
        attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
        output, _ = self.o_proj(attn_output)
        return output
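This mixed causal/bidirectional pattern is exactly what the PrefixLM support named in the commit title lets the TritonAttention backend handle natively, which is why this naive fallback can be deleted. For context, here is a minimal, self-contained sketch of how such a mask can be built; the helper name, the `image_spans` format, and the sliding-window handling are illustrative assumptions, not code from this diff:

```python
import torch

def build_bidirectional_image_mask(
    seq_len: int,
    image_spans: list[tuple[int, int]],  # hypothetical (start, end) token ranges
    sliding_window: int | None = None,
) -> torch.Tensor:
    """Boolean mask where True means "query may attend to key"."""
    # Standard causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Tokens belonging to the same image attend to each other bidirectionally.
    for start, end in image_spans:
        mask[start:end, start:end] = True
    if sliding_window is not None:
        # Local layers additionally restrict each query to a recent window of keys.
        idx = torch.arange(seq_len)
        mask &= (idx[:, None] - idx[None, :]) < sliding_window
    return mask

# 6 tokens, with tokens 1..3 coming from one image: token 1 may attend to
# tokens 2 and 3 even though they appear later in the sequence.
print(build_bidirectional_image_mask(6, [(1, 4)]).int())
# A local-layer variant with a 2-position sliding window.
print(build_bidirectional_image_mask(6, [(1, 4)], sliding_window=2).int())
```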
    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # NOTE(woosuk): As described in the comment above, this code is not
        # meant to be performant. It is only meant to be correct.
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        # Local (sliding-window) layers and global layers use different masks.
        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        # Sequences are packed along the token dimension; attend to each
        # sequence separately with its own mask.
        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose to (1, num_heads, seq_len, head_dim) for SDPA.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            # NOTE: `scale` is keyword-only in F.scaled_dot_product_attention;
            # passing it positionally would set `dropout_p` instead.
            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                scale=self.scaling,
            )
            # Back to (seq_len, num_heads * head_dim) and write into `out`.
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out
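As a side note on the GQA expansion inside `naive_attn_with_masks`: `repeat_interleave` duplicates each KV head consecutively, so each group of adjacent query heads shares one KV head. A tiny standalone demonstration; all shapes here are made up for illustration:

```python
import torch

# Hypothetical GQA shapes: 8 query heads sharing 2 KV heads.
num_heads, num_kv_heads, head_dim = 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

k = torch.randn(10, num_kv_heads, head_dim)  # (num_tokens, num_kv_heads, head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
print(k.shape)  # torch.Size([10, 8, 64]): one key head per query head

# repeat_interleave repeats consecutively, so query heads 0..3 all see KV head 0.
assert torch.equal(k[:, 0], k[:, 3])
```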

class Gemma3DecoderLayer(nn.Module):
    def __init__(