Add topk logits torch op for DS3.2. (#25945)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
@@ -643,17 +643,24 @@ def sparse_attn_indexer(
                 chunk.cu_seqlen_ks,
                 chunk.cu_seqlen_ke,
             )
-            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
-            topk_indices -= chunk.cu_seqlen_ks[:, None]
-            mask_lo = topk_indices >= 0
-            mask_hi = (
-                topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0
-            )
-            mask = torch.full_like(
-                topk_indices, False, dtype=torch.bool, device=topk_indices.device
-            )
-            mask = mask_lo & mask_hi
-            topk_indices = topk_indices.masked_fill(~mask, -1)
+            num_rows = logits.shape[0]
+            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
+            topk_indices = torch.empty(
+                num_rows, topk_tokens, dtype=torch.int32, device=logits.device
+            )
+            topk_values = torch.empty(
+                num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
+            )
+            torch.ops._C.top_k_per_row(
+                logits,
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+                topk_indices,
+                topk_values,
+                num_rows,
+                logits.stride(0),
+                logits.stride(1),
+            )
             topk_indices_buffer[
                 chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
             ] = topk_indices.to(dtype=torch.int32)
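The removed eager path doubles as a reference for checking the fused kernel. Below is a minimal sketch of such a check, not part of this commit: top_k_per_row_ref is a hypothetical helper name, and it assumes torch.ops._C.top_k_per_row keeps the semantics of the code it replaces (per-row top-k indices shifted to be relative to cu_seqlen_ks, with out-of-window slots set to -1), which the unchanged topk_indices_buffer consumer suggests.

    import torch

    def top_k_per_row_ref(logits, ks, ke, topk=2048):
        # Mirrors the eager code removed above: top-k over the full row,
        # indices shifted to be relative to ks, and anything outside the
        # window [0, ke - ks) replaced with -1.
        indices = logits.topk(min(topk, logits.shape[-1]), dim=-1)[1]
        indices = indices - ks[:, None]
        valid = (indices >= 0) & (indices < (ke - ks)[:, None])
        return indices.masked_fill(~valid, -1).to(torch.int32)

Since the top-k order among tied logits is not unique, a robust comparison against the kernel output would compare per-row index sets rather than exact index positions.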
@@ -693,28 +700,32 @@ def sparse_attn_indexer(
         # padded query len
         current_device = padded_q_fp8_decode_tokens.device
         padded_num_tokens = batch_size * next_n
         positions = (
             torch.arange(max_model_len, device=current_device)
             .unsqueeze(0)
             .expand(batch_size * next_n, -1)
         )
         row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
         next_n_offset = (
             torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
             % next_n
         )
         index_end_pos = (
-            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset
+            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
         ).unsqueeze(1)
         # index_end_pos: [B * N, 1]
-        mask = positions <= index_end_pos
-        # mask: [B * N, L]
-        logits = logits.masked_fill(~mask, float("-inf"))
-        topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32)  # [B * N, K]
-        # ensure we don't set indices for the top k
-        # that is out of range(masked already)
-        # this will happen if context length is shorter than K
-        topk_indices[topk_indices > index_end_pos] = -1
+        num_rows = logits.shape[0]
+        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
+        topk_indices = torch.empty(
+            num_rows, topk_tokens, dtype=torch.int32, device=logits.device
+        )
+        topk_values = torch.empty(
+            num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
+        )
+        torch.ops._C.top_k_per_row(
+            logits,
+            torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
+            index_end_pos.to(dtype=torch.int32, device=logits.device),
+            topk_indices,
+            topk_values,
+            num_rows,
+            logits.stride(0),
+            logits.stride(1),
+        )
         if decode_metadata.requires_padding:
             # if padded, we need to unpack
             # the topk indices removing padded tokens
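Why the decode path gains a "+ 1": the old code kept positions with an inclusive mask (positions <= index_end_pos) before calling topk, while the kernel takes a per-row end bound that reads as exclusive, so the last attendable position seq_len - next_n + offset becomes the exclusive end seq_len - next_n + offset + 1, with the window start pinned to zero via the torch.zeros argument. A standalone toy check of that equivalence, not from the commit:

    import torch

    # For one batch row with sequence length S and speculative offset j
    # (0 <= j < next_n), the last visible position is S - next_n + j.
    S, next_n, max_model_len = 10, 2, 16
    positions = torch.arange(max_model_len)
    for j in range(next_n):
        old_mask = positions <= (S - next_n + j)     # removed inclusive mask
        new_mask = positions < (S - next_n + j + 1)  # exclusive [0, ke) window
        assert torch.equal(old_mask, new_mask)

This also explains why the Python-side masked_fill with -inf could be dropped: the kernel never looks past the ke bound, so out-of-range columns no longer need to be neutralized before the top-k.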