[Attention][2/n] Remove usage of deprecated seq_lens_cpu and num_computed_tokens_cpu CommonAttentionMetadata properties (#31774)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Lucas Wilkinson
Date: 2026-01-06 12:32:14 -05:00
Committed by: GitHub
Parent: 2f4bdee61e
Commit: 4c73be14e0
3 changed files with 13 additions and 14 deletions
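For context, a minimal sketch of the migration this commit applies across all three builders, assuming (as the diffs below suggest) that compute_num_computed_tokens() computes on demand and returns a device-resident tensor, so callers take a .cpu() copy only where a host value is needed. The Meta class is a hypothetical stand-in, not the real CommonAttentionMetadata:

import torch

# Hypothetical stand-in for CommonAttentionMetadata; only the accessor
# shape matters here, not the real implementation.
class Meta:
    def __init__(self, num_computed_tokens: torch.Tensor):
        self._t = num_computed_tokens  # device-resident in the real builders

    def compute_num_computed_tokens(self) -> torch.Tensor:
        # Assumed behavior: computed on demand, returned on the device.
        return self._t

m = Meta(torch.tensor([4, 7]))

# Before (deprecated): a cached CPU tensor that callers copied to the device:
#   context_lens_tensor = m.num_computed_tokens_cpu.to(device, non_blocking=True)
# After: one method call; take a host copy only where one is actually needed.
context_lens_tensor = m.compute_num_computed_tokens()
num_computed_tokens_cpu = context_lens_tensor.cpu()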


@@ -142,8 +142,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
         m = common_attn_metadata
         query_start_loc = m.query_start_loc
-        context_lens = m.num_computed_tokens_cpu
-        context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
+        context_lens_tensor = m.compute_num_computed_tokens()
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
         if (
@@ -370,6 +369,5 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
         num_accepted_tokens = torch.diff(m.query_start_loc)
         num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
-        m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
         return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
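The deleted fixup above wrote the private _num_computed_tokens_cpu field by hand from seq_lens_cpu and the accepted-token counts. A sketch of that same arithmetic, which the builder now delegates to compute_num_computed_tokens() (the semantics here are inferred from the deleted line, not from the method's source):

import torch

seq_lens = torch.tensor([7, 12])           # total tokens per request
query_start_loc = torch.tensor([0, 3, 8])  # cumulative scheduled tokens
num_accepted_tokens = torch.diff(query_start_loc)  # per-request query lengths: [3, 5]
# Tokens already in the cache = sequence length minus this step's tokens.
num_computed_tokens = seq_lens - num_accepted_tokens  # tensor([4, 7])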


@@ -215,7 +215,10 @@ class Mamba2AttentionMetadataBuilder(
         num_prefills = common.num_prefills
         num_decode_tokens = common.num_decode_tokens
-        num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
+        num_computed_tokens_cpu = (
+            common_attn_metadata.compute_num_computed_tokens().cpu()
+        )
+        num_computed_tokens_p_cpu = num_computed_tokens_cpu[
             num_reqs - num_prefills : num_reqs
         ]
         query_start_loc_p_cpu = (
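The Mamba2 change computes the tensor once, takes a single host copy, and slices out the prefill requests. A toy version of that slice, assuming the decode-first, prefill-last batch ordering the index range implies:

import torch

num_reqs, num_prefills = 4, 2
num_computed_tokens_cpu = torch.tensor([4, 7, 0, 16])  # one host copy, made once

# Decodes occupy the front of the batch, prefills the tail, so the prefill
# entries are the trailing num_prefills rows:
num_computed_tokens_p_cpu = num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs]
# -> tensor([ 0, 16])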


@@ -138,9 +138,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         common_attn_metadata: CommonAttentionMetadata,
         mamba_block_size: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-            self.device
-        )
+        num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
         # Block index of the last computed token
         block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
         # which is <= block index for the first scheduled token
@@ -193,13 +191,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
         if self.vllm_config.cache_config.enable_prefix_caching:
+            num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
             # Return a tensor of shape (#requests, #max blocks)
             state_indices_tensor = common_attn_metadata.block_table_tensor
             # Additional cache-related variables:
             mamba_block_size = self.kv_cache_spec.block_size
-            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
-                self.device
-            )
             (
                 block_idx_last_computed_token,
                 block_idx_first_scheduled_token,
@@ -212,15 +209,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
         if num_prefills > 0:
+            if num_computed_tokens is None:
+                num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
+            num_computed_tokens_cpu = num_computed_tokens.cpu()
             query_start_loc_p = (
                 common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                 - num_decode_tokens
             )
             has_initial_states_cpu = (
-                common_attn_metadata.num_computed_tokens_cpu[
-                    num_reqs - num_prefills : num_reqs
-                ]
-                > 0
+                num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
             )
             has_initial_states_p = has_initial_states_cpu.to(
                 common_attn_metadata.query_start_loc.device
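The prefill branch above guards the call with `if num_computed_tokens is None` so a tensor already produced on the prefix-caching path is reused rather than recomputed. A minimal sketch of that lazy pattern, with _compute standing in (hypothetically) for compute_num_computed_tokens():

import torch

def _compute() -> torch.Tensor:
    # Stand-in for common_attn_metadata.compute_num_computed_tokens().
    return torch.tensor([4, 7])

num_computed_tokens: torch.Tensor | None = None
# ... an earlier path (e.g. prefix caching) may already have filled this in ...
if num_computed_tokens is None:
    num_computed_tokens = _compute()
num_computed_tokens_cpu = num_computed_tokens.cpu()  # single host copy, reused below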