[GDN] Use CPU tensors to build GDN metadata (#34498)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -206,21 +206,24 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
assert spec_sequence_masks_cpu is not None
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
# Use CPU tensors to avoid CPU-GPU sync
|
||||
non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu]
|
||||
num_decodes = (non_spec_query_lens_cpu == 1).sum().item()
|
||||
# Exclude zero-length padded sequences from prefill count.
|
||||
num_zero_len = (non_spec_query_lens == 0).sum().item()
|
||||
num_prefills = non_spec_query_lens.size(0) - num_decodes - num_zero_len
|
||||
num_zero_len = (non_spec_query_lens_cpu == 0).sum().item()
|
||||
num_prefills = non_spec_query_lens_cpu.size(0) - num_decodes - num_zero_len
|
||||
num_decode_tokens = num_decodes
|
||||
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
||||
num_prefill_tokens = (
|
||||
non_spec_query_lens_cpu.sum().item() - num_decode_tokens
|
||||
)
|
||||
num_spec_decode_tokens = (
|
||||
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
|
||||
query_lens_cpu.sum().item() - num_prefill_tokens - num_decode_tokens
|
||||
)
|
||||
|
||||
if num_prefills == 0 and num_decodes == 0:
|
||||
spec_token_size = min(
|
||||
num_spec_decodes * (self.num_spec + 1),
|
||||
query_start_loc[-1].item(),
|
||||
query_start_loc_cpu[-1].item(),
|
||||
)
|
||||
spec_token_indx = torch.arange(
|
||||
spec_token_size,
|
||||
|
||||
@@ -775,10 +775,10 @@ def compute_causal_conv1d_metadata(
|
||||
MAX_NUM_PROGRAMS
|
||||
).fill_(PAD_SLOT_ID)
|
||||
|
||||
batch_ptr[0:mlist_len].copy_(mlist)
|
||||
batch_ptr[0:mlist_len].copy_(mlist, non_blocking=True)
|
||||
token_chunk_offset_ptr[ # type: ignore
|
||||
0:mlist_len
|
||||
].copy_(offsetlist)
|
||||
].copy_(offsetlist, non_blocking=True)
|
||||
nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
|
||||
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore
|
||||
|
||||
|
||||
Reference in New Issue
Block a user