Add topk logits torch op for DS3.2. (#25945)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Daniel Cámpora
2025-10-07 12:07:32 +02:00
committed by GitHub
parent d100d78eb3
commit e1098ced95
5 changed files with 446 additions and 25 deletions

View File

@@ -643,17 +643,24 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
topk_indices -= chunk.cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = (
topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = torch.empty(
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
)
mask = torch.full_like(
topk_indices, False, dtype=torch.bool, device=topk_indices.device
topk_values = torch.empty(
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
)
torch.ops._C.top_k_per_row(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
topk_values,
num_rows,
logits.stride(0),
logits.stride(1),
)
mask = mask_lo & mask_hi
topk_indices = topk_indices.masked_fill(~mask, -1)
topk_indices_buffer[
chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
] = topk_indices.to(dtype=torch.int32)
@@ -693,28 +700,32 @@ def sparse_attn_indexer(
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
positions = (
torch.arange(max_model_len, device=current_device)
.unsqueeze(0)
.expand(batch_size * next_n, -1)
)
row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
next_n_offset = (
torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
% next_n
)
index_end_pos = (
decode_metadata.seq_lens[row_indices] - next_n + next_n_offset
decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices[topk_indices > index_end_pos] = -1
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = torch.empty(
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
)
topk_values = torch.empty(
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
)
torch.ops._C.top_k_per_row(
logits,
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
index_end_pos.to(dtype=torch.int32, device=logits.device),
topk_indices,
topk_values,
num_rows,
logits.stride(0),
logits.stride(1),
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens