[Attention][1/n] Remove usage of deprecated seq_lens_cpu and num_computed_tokens_cpu CommonAttentionMetadata properties (#31773)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -870,7 +870,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# Guard access to seq_lens_cpu, which may not always be needed
|
||||
# and can be expensive to retrieve in async mode.
|
||||
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
|
||||
seq_lens_cpu = (
|
||||
common_attn_metadata.seq_lens.cpu() if needs_seq_lens_cpu else None
|
||||
)
|
||||
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
|
||||
num_blocks_np = (
|
||||
(seq_lens_np + (page_size - 1)) // page_size
|
||||
|
||||
@@ -727,9 +727,7 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
|
||||
block_table_tensor, seq_lens, block_size, num_gpu_blocks
|
||||
)
|
||||
|
||||
offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
offset_tensor = common_attn_metadata.compute_num_computed_tokens()
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
causal=common_attn_metadata.causal,
|
||||
|
||||
@@ -791,7 +791,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
|
||||
@@ -511,7 +511,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
# For pure decode batches, prefill_request_id will be None
|
||||
# For mixed batches, it will have -1 for decode and request_id for prefill
|
||||
if num_prefills > 0:
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens.cpu()
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
||||
prefix_kv_lens = torch.tensor(
|
||||
[common_prefix_len], dtype=torch.int32, device=self.device
|
||||
)
|
||||
suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
|
||||
suffix_kv_lens = common_attn_metadata.seq_lens.cpu() - common_prefix_len
|
||||
suffix_kv_lens = suffix_kv_lens.to(self.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
|
||||
@@ -100,6 +100,8 @@ class CommonAttentionMetadata:
|
||||
_seq_lens_cpu: torch.Tensor | None = None
|
||||
_num_computed_tokens_cpu: torch.Tensor | None = None
|
||||
|
||||
_num_computed_tokens_cache: torch.Tensor | None = None
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
@@ -130,6 +132,13 @@ class CommonAttentionMetadata:
|
||||
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
|
||||
return self._num_computed_tokens_cpu
|
||||
|
||||
# Lazily derives num_computed_tokens from seq_lens and query_start_loc and
# memoizes the result in _num_computed_tokens_cache, so repeated calls on the
# same metadata object return the same tensor without recomputing it.
def compute_num_computed_tokens(self) -> torch.Tensor:
|
||||
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
|
||||
# Only compute on the first call; afterwards the cached tensor is reused.
if self._num_computed_tokens_cache is None:
|
||||
# Adjacent differences of query_start_loc yield the query lengths.
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
|
||||
# NOTE(review): assumes seq_lens and query_start_loc live on the same
# device and that seq_lens has one entry per query_lens entry - confirm.
self._num_computed_tokens_cache = self.seq_lens - query_lens
|
||||
return self._num_computed_tokens_cache
|
||||
|
||||
# TODO(lucas): remove once we have FULL-CG spec-decode support
|
||||
def unpadded(
|
||||
self, num_actual_tokens: int, num_actual_reqs: int
|
||||
|
||||
Reference in New Issue
Block a user