[Feature]: Remove DtoH Copy for lfm2_vl On Default Stream (#32815)

Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
This commit is contained in:
tianshu-Michael-yu
2026-01-23 05:20:30 -08:00
committed by GitHub
parent 10e94c84f6
commit 13d8746c54
5 changed files with 260 additions and 158 deletions

View File

@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
m = common_attn_metadata
query_start_loc = m.query_start_loc
query_start_loc_cpu = m.query_start_loc_cpu
context_lens_tensor = m.compute_num_computed_tokens()
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
spec_sequence_masks_cpu: torch.Tensor | None = None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks_cpu.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
spec_sequence_masks_cpu = None
else:
spec_sequence_masks = spec_sequence_masks.to(
spec_sequence_masks = spec_sequence_masks_cpu.to(
query_start_loc.device, non_blocking=True
)
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert spec_sequence_masks_cpu is not None
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
non_spec_query_start_loc_cpu = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
dim=0,
out=non_spec_query_start_loc[1:],
)
non_spec_query_start_loc_cpu = torch.zeros(
query_lens_cpu.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
)
torch.cumsum(
query_lens_cpu[~spec_sequence_masks_cpu],
dim=0,
out=non_spec_query_start_loc_cpu[1:],
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
assert non_spec_query_start_loc_cpu is not None
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
compute_causal_conv1d_metadata(
non_spec_query_start_loc_cpu,
device=query_start_loc.device,
)
)
else:
has_initial_state = None