[Misc] Add max_seq_len to CommonAttentionMetadata (#23216)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-08-20 09:05:29 -07:00
committed by GitHub
parent 5efd6905bc
commit d6d13bd49e
12 changed files with 22 additions and 7 deletions

View File

@@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder(
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
-max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu

View File

@@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
-max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
+max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor

View File

@@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder(
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
-max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor

View File

@@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder(
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
-max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor

View File

@@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder(
q_start_loc = common_attn_metadata.query_start_loc
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
-max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+max_seq_len = common_attn_metadata.max_seq_len
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

View File

@@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder(
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
-max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor

View File

@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
+max_seq_len: int
+"""Longest context length in batch"""
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(
seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
request_slice]
@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
num_reqs=num_requests,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
+max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
)
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
+max_seq_len = int(seq_lens_cpu.max())
return CommonAttentionMetadata(
query_start_loc_cpu=query_start_loc_cpu,
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
num_reqs=len(seq_lens_cpu),
num_actual_tokens=common_attn_metadata.num_actual_tokens,
max_query_len=seqlens_q_local.max(),
+max_seq_len=max_seq_len,
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,

View File

@@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder(
q_seqlens = torch.diff(q_start_loc)
max_query_len = common_attn_metadata.max_query_len
kv_seqlens = common_attn_metadata.seq_lens
-max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+max_seq_len = common_attn_metadata.max_seq_len
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping