[Attention][1/n] Remove usage of deprecated seq_lens_cpu and num_computed_tokens_cpu CommonAttentionMetadata properties (#31773)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2026-01-06 07:05:17 -05:00
committed by GitHub
parent 14df02b4e1
commit e0327c9db2
9 changed files with 23 additions and 12 deletions
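
The pattern applied across the hunks below: the deprecated seq_lens_cpu and num_computed_tokens_cpu properties eagerly cached host copies of device tensors; callers now keep seq_lens on device and call .cpu() (or the new compute_num_computed_tokens() helper) only where host values are actually required. A minimal runnable sketch of the before/after access pattern, with _MetadataSketch as a hypothetical stand-in for CommonAttentionMetadata rather than the real class:

    from dataclasses import dataclass

    import torch


    @dataclass
    class _MetadataSketch:
        """Hypothetical stand-in for CommonAttentionMetadata, for illustration only."""

        query_start_loc: torch.Tensor  # (num_reqs + 1,) cumulative query lengths, on device
        seq_lens: torch.Tensor  # (num_reqs,) total tokens per request, on device

        def compute_num_computed_tokens(self) -> torch.Tensor:
            # Same arithmetic as the helper added in the last hunk of this commit:
            # per-request context length = seq_len - number of new query tokens.
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            return self.seq_lens - query_lens


    meta = _MetadataSketch(torch.tensor([0, 2, 6]), torch.tensor([9, 6]))

    # Before (deprecated): meta.seq_lens_cpu, meta.num_computed_tokens_cpu
    # After: stay on device; copy to host only where host values are needed.
    seq_lens_cpu = meta.seq_lens.cpu()
    num_computed_tokens_cpu = meta.compute_num_computed_tokens().cpu()  # tensor([7, 2])

Each builder change below is an instance of this pattern; the final hunk adds the helper and its cache field to CommonAttentionMetadata.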


@@ -870,7 +870,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # Guard access to seq_lens_cpu, which may not always be needed
         # and can be expensive to retrieve in async mode.
         needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
+        seq_lens_cpu = (
+            common_attn_metadata.seq_lens.cpu() if needs_seq_lens_cpu else None
+        )
         seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
         num_blocks_np = (
             (seq_lens_np + (page_size - 1)) // page_size

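The guarded host copy above feeds host-side arithmetic like the block count immediately after it. A hypothetical standalone version of that ceil-division (page_size is an assumed example value), illustrating why the .cpu() transfer is only worth paying when one of these consumers actually runs:

    import numpy as np
    import torch


    def num_blocks_per_seq(seq_lens: torch.Tensor, page_size: int = 16) -> np.ndarray:
        # Illustrative helper: ceil(seq_len / page_size) pages are needed per sequence.
        # .cpu() is the device-to-host copy the builder guards with needs_seq_lens_cpu.
        seq_lens_np = seq_lens.cpu().numpy()
        return (seq_lens_np + (page_size - 1)) // page_size


    print(num_blocks_per_seq(torch.tensor([1, 16, 17])))  # [1 1 2]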

@@ -727,9 +727,7 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
             block_table_tensor, seq_lens, block_size, num_gpu_blocks
         )
-        offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
-            self.device, non_blocking=True
-        )
+        offset_tensor = common_attn_metadata.compute_num_computed_tokens()
         out = FlexAttentionMetadata(
             causal=common_attn_metadata.causal,

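A worked example (values assumed purely for illustration) of the offset_tensor produced by compute_num_computed_tokens() above, following the definition added in the last hunk of this commit:

    import torch

    query_start_loc = torch.tensor([0, 3, 7, 8])  # three requests with 3, 4, 1 new query tokens
    seq_lens = torch.tensor([10, 4, 25])  # total tokens per request, including context

    query_lens = query_start_loc[1:] - query_start_loc[:-1]  # tensor([3, 4, 1])
    offset_tensor = seq_lens - query_lens  # tensor([7, 0, 24]): already-computed tokens per request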

@@ -791,7 +791,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         prefill_metadata = None
         if num_prefills > 0:
-            num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
+            num_computed_tokens_cpu = (
+                common_attn_metadata.compute_num_computed_tokens().cpu()
+            )
             reqs_start = num_decodes  # prefill_start


@@ -511,7 +511,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         # For pure decode batches, prefill_request_id will be None
         # For mixed batches, it will have -1 for decode and request_id for prefill
         if num_prefills > 0:
-            seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+            seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
             seq_lens = common_attn_metadata.seq_lens
             query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu


@@ -221,7 +221,7 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
             prefix_kv_lens = torch.tensor(
                 [common_prefix_len], dtype=torch.int32, device=self.device
             )
-            suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
+            suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
             suffix_kv_lens = suffix_kv_lens.to(self.device)
         else:
             cu_prefix_query_lens = None


@@ -100,6 +100,8 @@ class CommonAttentionMetadata:
     _seq_lens_cpu: torch.Tensor | None = None
     _num_computed_tokens_cpu: torch.Tensor | None = None
+    _num_computed_tokens_cache: torch.Tensor | None = None
 
     @property
     @deprecated(
         """
@@ -130,6 +132,13 @@ class CommonAttentionMetadata:
             self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
         return self._num_computed_tokens_cpu
 
+    def compute_num_computed_tokens(self) -> torch.Tensor:
+        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
+        if self._num_computed_tokens_cache is None:
+            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
+            self._num_computed_tokens_cache = self.seq_lens - query_lens
+        return self._num_computed_tokens_cache
+
     # TODO(lucas): remove once we have FULL-CG spec-decode support
     def unpadded(
         self, num_actual_tokens: int, num_actual_reqs: int