[perf] Speed up align sum kernels (#21079)
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
This commit is contained in:
@@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
|
||||
sorted_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
|
||||
expert_ids_triton = torch.zeros(
|
||||
expert_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
|
||||
sorted_ids_vllm.fill_(topk_ids.numel())
|
||||
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
|
||||
expert_ids_vllm = torch.empty_like(expert_ids_triton)
|
||||
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
|
||||
|
||||
# 2. run implementations
|
||||
@@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
|
||||
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
Reference in New Issue
Block a user