[GDN] Use CPU tensors to build GDN metadata (#34498)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-13 01:24:45 -08:00
committed by GitHub
parent 3d2a026fd0
commit 0916e7960b
2 changed files with 12 additions and 9 deletions

View File

@@ -206,21 +206,24 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
assert spec_sequence_masks_cpu is not None
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
# Use CPU tensors to avoid CPU-GPU sync
non_spec_query_lens_cpu = query_lens_cpu[~spec_sequence_masks_cpu]
num_decodes = (non_spec_query_lens_cpu == 1).sum().item()
# Exclude zero-length padded sequences from prefill count.
num_zero_len = (non_spec_query_lens == 0).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes - num_zero_len
num_zero_len = (non_spec_query_lens_cpu == 0).sum().item()
num_prefills = non_spec_query_lens_cpu.size(0) - num_decodes - num_zero_len
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
num_prefill_tokens = (
non_spec_query_lens_cpu.sum().item() - num_decode_tokens
)
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
query_lens_cpu.sum().item() - num_prefill_tokens - num_decode_tokens
)
if num_prefills == 0 and num_decodes == 0:
spec_token_size = min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc[-1].item(),
query_start_loc_cpu[-1].item(),
)
spec_token_indx = torch.arange(
spec_token_size,

View File

@@ -775,10 +775,10 @@ def compute_causal_conv1d_metadata(
MAX_NUM_PROGRAMS
).fill_(PAD_SLOT_ID)
batch_ptr[0:mlist_len].copy_(mlist)
batch_ptr[0:mlist_len].copy_(mlist, non_blocking=True)
token_chunk_offset_ptr[ # type: ignore
0:mlist_len
].copy_(offsetlist)
].copy_(offsetlist, non_blocking=True)
nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore