[v1] Add PrefixLM support to TritonAttention backend (#30386)
This commit is contained in:
@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||
|
||||
@property
|
||||
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
|
||||
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
|
||||
|
||||
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
|
||||
Empty ranges have start==end==0, which kernel skips via is_valid check.
|
||||
"""
|
||||
# TODO(Isotr0py): Move to model runner's attention metadata
|
||||
# preparation to avoid duplicate computation.
|
||||
if self.mm_prefix_range is None:
|
||||
return None
|
||||
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
device = self.seq_lens.device
|
||||
|
||||
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
|
||||
range_lists = [
|
||||
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
|
||||
]
|
||||
|
||||
# Return None if all ranges are trivial (only (0,0) placeholders)
|
||||
if all(r == [(0, 0)] for r in range_lists):
|
||||
return None
|
||||
|
||||
# Create 2D tensors with shape (num_ranges, 2) for each sequence
|
||||
range_tensors = [
|
||||
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
|
||||
for r in range_lists
|
||||
]
|
||||
|
||||
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size >= 32
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
mm_prefix_range=mm_prefix_range_tensor,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user