[Attention][2/n] Remove usage of deprecated seq_lens_cpu and num_computed_tokens_cpu CommonAttentionMetadata properties (#31774)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -142,8 +142,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
m = common_attn_metadata
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device, non_blocking=True)
|
||||
context_lens_tensor = m.compute_num_computed_tokens()
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (
|
||||
@@ -370,6 +369,5 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
|
||||
num_accepted_tokens = torch.diff(m.query_start_loc)
|
||||
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
|
||||
m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
|
||||
|
||||
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
|
||||
|
||||
@@ -215,7 +215,10 @@ class Mamba2AttentionMetadataBuilder(
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_computed_tokens_cpu = (
|
||||
common_attn_metadata.compute_num_computed_tokens().cpu()
|
||||
)
|
||||
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
query_start_loc_p_cpu = (
|
||||
|
||||
@@ -138,9 +138,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
mamba_block_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
# Block index of the last computed token
|
||||
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
|
||||
# which is <= block index for the first scheduled token
|
||||
@@ -193,13 +191,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
# Additional cache-related varaiables:
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
@@ -212,15 +209,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
if num_prefills > 0:
|
||||
if num_computed_tokens is None:
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
num_computed_tokens_cpu = num_computed_tokens.cpu()
|
||||
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
|
||||
Reference in New Issue
Block a user