[Feature]: Remove DtoH Copy for lfm2_vl On Default Stream (#32815)
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
This commit is contained in:
committed by
GitHub
parent
10e94c84f6
commit
13d8746c54
@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
m = common_attn_metadata
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
query_start_loc_cpu = m.query_start_loc_cpu
|
||||
context_lens_tensor = m.compute_num_computed_tokens()
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
spec_sequence_masks_cpu: torch.Tensor | None = None
|
||||
if (
|
||||
not self.use_spec_decode
|
||||
or num_decode_draft_tokens_cpu is None
|
||||
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_sequence_masks = None
|
||||
num_spec_decodes = 0
|
||||
else:
|
||||
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks_cpu.sum().item()
|
||||
if num_spec_decodes == 0:
|
||||
spec_sequence_masks = None
|
||||
spec_sequence_masks_cpu = None
|
||||
else:
|
||||
spec_sequence_masks = spec_sequence_masks.to(
|
||||
spec_sequence_masks = spec_sequence_masks_cpu.to(
|
||||
query_start_loc.device, non_blocking=True
|
||||
)
|
||||
|
||||
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc_cpu = query_start_loc_cpu
|
||||
num_accepted_tokens = None
|
||||
else:
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
assert spec_sequence_masks_cpu is not None
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
non_spec_query_start_loc_cpu = None
|
||||
else:
|
||||
spec_token_masks = torch.repeat_interleave(
|
||||
spec_sequence_masks, query_lens
|
||||
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:],
|
||||
)
|
||||
non_spec_query_start_loc_cpu = torch.zeros(
|
||||
query_lens_cpu.size(0) - num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens_cpu[~spec_sequence_masks_cpu],
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc_cpu[1:],
|
||||
)
|
||||
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
assert non_spec_query_start_loc_cpu is not None
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(non_spec_query_start_loc)
|
||||
compute_causal_conv1d_metadata(
|
||||
non_spec_query_start_loc_cpu,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
has_initial_state = None
|
||||
|
||||
Reference in New Issue
Block a user