[perf] Speed up align sum kernels (#21079)

Signed-off-by: Himanshu Jaju <hj@mistral.ai>
This commit is contained in:
Himanshu Jaju
2025-07-21 19:19:23 +01:00
committed by GitHub
parent 005ae9be6c
commit 0ec82edda5
3 changed files with 60 additions and 25 deletions

View File

@@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
sorted_ids_triton = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
expert_ids_triton = torch.zeros(
expert_ids_triton = torch.empty(
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
expert_ids_vllm = torch.empty_like(expert_ids_triton)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
# 2. run implementations
@@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")