[Misc] Add max_seq_len to CommonAttentionMetadata (#23216)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder(
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
|
||||
@@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
page_size = self.page_size
|
||||
max_q_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
@@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder(
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
@@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
@@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder(
|
||||
q_start_loc = common_attn_metadata.query_start_loc
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
kv_seqlens = common_attn_metadata.seq_lens
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder(
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
|
||||
@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
|
||||
"""Total number of tokens in batch"""
|
||||
max_query_len: int
|
||||
"""Longest query in batch"""
|
||||
max_seq_len: int
|
||||
"""Longest context length in batch"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(
|
||||
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
|
||||
request_slice]
|
||||
|
||||
@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
|
||||
num_reqs=num_requests,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
)
|
||||
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(
|
||||
|
||||
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
|
||||
seq_lens_cpu = torch.from_numpy(seqlens_k_local)
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
|
||||
num_reqs=len(seq_lens_cpu),
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
max_query_len=seqlens_q_local.max(),
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_local,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
causal=True,
|
||||
|
||||
@@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder(
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
kv_seqlens = common_attn_metadata.seq_lens
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
|
||||
Reference in New Issue
Block a user