[DeepSeek v3.2] Remove extra logic in indexer (#26465)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: Lain <siyuanf@nvidia.com>
Co-authored-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
@@ -580,9 +580,9 @@ def sparse_attn_indexer(
|
||||
)
|
||||
num_rows = logits.shape[0]
|
||||
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
||||
topk_indices = torch.empty(
|
||||
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||
)
|
||||
topk_indices = topk_indices_buffer[
|
||||
chunk.token_start : chunk.token_end, :topk_tokens
|
||||
]
|
||||
torch.ops._C.top_k_per_row(
|
||||
logits,
|
||||
chunk.cu_seqlen_ks,
|
||||
@@ -592,9 +592,6 @@ def sparse_attn_indexer(
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
)
|
||||
topk_indices_buffer[
|
||||
chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
|
||||
] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
if has_decode:
|
||||
decode_metadata = attn_metadata.decode
|
||||
@@ -628,26 +625,14 @@ def sparse_attn_indexer(
|
||||
decode_metadata.schedule_metadata,
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
# padded query len
|
||||
current_device = padded_q_fp8_decode_tokens.device
|
||||
padded_num_tokens = batch_size * next_n
|
||||
row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n
|
||||
next_n_offset = (
|
||||
torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device)
|
||||
% next_n
|
||||
)
|
||||
index_end_pos = (
|
||||
decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
|
||||
).unsqueeze(1)
|
||||
num_rows = logits.shape[0]
|
||||
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
||||
topk_indices = torch.empty(
|
||||
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||
)
|
||||
torch.ops._C.top_k_per_row(
|
||||
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
|
||||
|
||||
torch.ops._C.top_k_per_row_decode(
|
||||
logits,
|
||||
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
|
||||
index_end_pos.to(dtype=torch.int32, device=logits.device),
|
||||
next_n,
|
||||
decode_metadata.seq_lens,
|
||||
topk_indices,
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
@@ -660,9 +645,9 @@ def sparse_attn_indexer(
|
||||
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
|
||||
decode_lens,
|
||||
)
|
||||
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
|
||||
topk_indices.to(dtype=torch.int32)
|
||||
)
|
||||
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
|
||||
topk_indices
|
||||
)
|
||||
|
||||
return topk_indices_buffer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user