[Feature]: Remove DtoH Copy for lfm2_vl On Default Stream (#32815)
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
parent 10e94c84f6
commit 13d8746c54
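In short: the metadata builders below were forcing device-to-host (DtoH) copies and implicit synchronizations on the default CUDA stream (`.cpu()` and `.sum().item()` on device tensors); the diff reroutes that bookkeeping through CPU mirrors that `common_attn_metadata` already carries, and ships results to the device with non-blocking copies. A minimal standalone sketch of the before/after pattern — names and values here are illustrative, not taken from vLLM:

import torch

def count_sync(mask_gpu: torch.Tensor) -> int:
    # Before: .item() on a CUDA tensor blocks until the default
    # stream drains and the scalar is copied back to the host.
    return int(mask_gpu.sum().item())

def count_no_sync(mask_cpu: torch.Tensor) -> int:
    # After: the identical reduction on a CPU mirror of the same
    # data never touches the device, so nothing synchronizes.
    return int(mask_cpu.sum().item())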
@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
         m = common_attn_metadata

         query_start_loc = m.query_start_loc
+        query_start_loc_cpu = m.query_start_loc_cpu
         context_lens_tensor = m.compute_num_computed_tokens()
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

+        spec_sequence_masks_cpu: torch.Tensor | None = None
         if (
             not self.use_spec_decode
             or num_decode_draft_tokens_cpu is None
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             spec_sequence_masks = None
             num_spec_decodes = 0
         else:
-            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
-            num_spec_decodes = spec_sequence_masks.sum().item()
+            spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
+            num_spec_decodes = spec_sequence_masks_cpu.sum().item()
             if num_spec_decodes == 0:
                 spec_sequence_masks = None
+                spec_sequence_masks_cpu = None
             else:
-                spec_sequence_masks = spec_sequence_masks.to(
+                spec_sequence_masks = spec_sequence_masks_cpu.to(
                     query_start_loc.device, non_blocking=True
                 )
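The hunk above is the core of the change for GDN attention: the spec-decode mask is derived from the CPU-resident `num_decode_draft_tokens_cpu`, the `.sum().item()` count happens on the host, and only the final mask is shipped to the device with `non_blocking=True`. A self-contained sketch of that pattern (values invented; the async copy only truly overlaps compute when the source buffer is pinned):

import torch

# Illustrative host-side draft-token counts; -1 marks non-spec requests.
num_decode_draft_tokens_cpu = torch.tensor([2, -1, 3, -1])

# Mask and reduction stay on the host: no device synchronization.
spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = int(spec_sequence_masks_cpu.sum().item())

if num_spec_decodes > 0 and torch.cuda.is_available():
    # Queue the HtoD copy without blocking the caller.
    spec_sequence_masks = spec_sequence_masks_cpu.to("cuda", non_blocking=True)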
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
             spec_query_start_loc = None
             non_spec_query_start_loc = query_start_loc
+            non_spec_query_start_loc_cpu = query_start_loc_cpu
             num_accepted_tokens = None
         else:
             query_lens = query_start_loc[1:] - query_start_loc[:-1]
+            assert spec_sequence_masks_cpu is not None
+            query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

             non_spec_query_lens = query_lens[~spec_sequence_masks]
             num_decodes = (non_spec_query_lens == 1).sum().item()
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 non_spec_state_indices_tensor = None
                 spec_query_start_loc = query_start_loc
                 non_spec_query_start_loc = None
+                non_spec_query_start_loc_cpu = None
             else:
                 spec_token_masks = torch.repeat_interleave(
                     spec_sequence_masks, query_lens
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                     dim=0,
                     out=non_spec_query_start_loc[1:],
                 )
+                non_spec_query_start_loc_cpu = torch.zeros(
+                    query_lens_cpu.size(0) - num_spec_decodes + 1,
+                    dtype=torch.int32,
+                )
+                torch.cumsum(
+                    query_lens_cpu[~spec_sequence_masks_cpu],
+                    dim=0,
+                    out=non_spec_query_start_loc_cpu[1:],
+                )

                 assert num_accepted_tokens is not None
                 num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
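The nine added lines mirror the device-side `query_start_loc` construction on the host: a preallocated zero vector plus `torch.cumsum(..., out=...)` over the non-spec query lengths. Roughly, with invented lengths:

import torch

# Illustrative per-request query lengths and spec-decode mask on the CPU.
query_lens_cpu = torch.tensor([4, 1, 6, 1], dtype=torch.int32)
spec_sequence_masks_cpu = torch.tensor([False, True, False, True])
num_spec_decodes = int(spec_sequence_masks_cpu.sum().item())

# Slot 0 stays 0; cumulative sums of the non-spec lengths fill the rest.
non_spec_query_start_loc_cpu = torch.zeros(
    query_lens_cpu.size(0) - num_spec_decodes + 1, dtype=torch.int32
)
torch.cumsum(
    query_lens_cpu[~spec_sequence_masks_cpu],
    dim=0,
    out=non_spec_query_start_loc_cpu[1:],
)
# non_spec_query_start_loc_cpu is now [0, 4, 10].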
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             has_initial_state = context_lens_tensor > 0
             if spec_sequence_masks is not None:
                 has_initial_state = has_initial_state[~spec_sequence_masks]
+            assert non_spec_query_start_loc_cpu is not None
             nums_dict, batch_ptr, token_chunk_offset_ptr = (
-                compute_causal_conv1d_metadata(non_spec_query_start_loc)
+                compute_causal_conv1d_metadata(
+                    non_spec_query_start_loc_cpu,
+                    device=query_start_loc.device,
+                )
             )
         else:
             has_initial_state = None
@@ -219,21 +219,24 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         if num_prefills > 0:
             if num_computed_tokens is None:
                 num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
-            num_computed_tokens_cpu = num_computed_tokens.cpu()

+            query_start_loc_p_cpu = (
+                common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
+                - num_decode_tokens
+            )
             query_start_loc_p = (
                 common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                 - num_decode_tokens
             )
-            has_initial_states_cpu = (
-                num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
-            )
-            has_initial_states_p = has_initial_states_cpu.to(
-                common_attn_metadata.query_start_loc.device
+            has_initial_states_p = (
+                num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
             )

             nums_dict, batch_ptr, token_chunk_offset_ptr = (
-                compute_causal_conv1d_metadata(query_start_loc_p)
+                compute_causal_conv1d_metadata(
+                    query_start_loc_p_cpu,
+                    device=common_attn_metadata.query_start_loc.device,
+                )
             )

             if self.vllm_config.cache_config.enable_prefix_caching:
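In the Mamba builder the blocking `num_computed_tokens.cpu()` copy is dropped outright: `has_initial_states_p` becomes a comparison on the device tensor, and the prefill `query_start_loc` slice is duplicated from the already-available `query_start_loc_cpu`. A small sketch of the removed round trip (shapes and values are made up):

import torch

if torch.cuda.is_available():
    num_computed_tokens = torch.tensor([0, 7, 0, 3], device="cuda")
    num_reqs, num_prefills = 4, 2

    # Before: DtoH copy, host-side compare, then HtoD copy back.
    #   has_initial_states_cpu = num_computed_tokens.cpu()[2:4] > 0
    #   has_initial_states_p = has_initial_states_cpu.to("cuda")

    # After: one on-device comparison, no host round trip at all.
    has_initial_states_p = (
        num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
    )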
@@ -732,13 +732,17 @@ def create_fast_prefill_custom_backend(
     return attn_backend


-def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
-    # Needed for causal_conv1d
-    seqlens = query_start_loc_p.diff().to("cpu")
+def compute_causal_conv1d_metadata(
+    query_start_loc_p_cpu: torch.Tensor,
+    *,
+    device: torch.device,
+):
+    # Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync.
+    assert query_start_loc_p_cpu.device.type == "cpu"
+    seqlens = query_start_loc_p_cpu.diff()
     nums_dict = {}  # type: ignore
     batch_ptr = None
     token_chunk_offset_ptr = None
-    device = query_start_loc_p.device
     for BLOCK_M in [8]:  # cover all BLOCK_M values
         nums = -(-seqlens // BLOCK_M)
         nums_dict[BLOCK_M] = {}
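With the new signature, callers hand `compute_causal_conv1d_metadata` the CPU copy of the prefill `query_start_loc` plus an explicit target `device`, so the old `query_start_loc_p.diff().to("cpu")` DtoH copy disappears. A hedged usage sketch, assuming the helper as defined in this diff is in scope:

import torch
# Assumed import; the actual module path may differ in the vLLM tree.
# from vllm.v1.attention.backends.utils import compute_causal_conv1d_metadata

# Illustrative host-side prefix sums for the prefill requests.
query_start_loc_p_cpu = torch.tensor([0, 4, 9, 17], dtype=torch.int32)

# seqlens (= [4, 5, 8]) is derived on the host; only the final metadata
# tensors are materialized on the requested device.
nums_dict, batch_ptr, token_chunk_offset_ptr = compute_causal_conv1d_metadata(
    query_start_loc_p_cpu,
    device=torch.device("cuda"),
)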