Add SM100 kernels (#201)

Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Simon Mo
2025-09-29 02:07:28 -07:00
committed by GitHub
parent 80ceeb2c76
commit 59f2c07cf2
6 changed files with 808 additions and 10 deletions

View File

@@ -204,7 +204,7 @@ def test_paged_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, 132)
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms())
logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True)
ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len)
@@ -229,7 +229,7 @@ def test_paged_mqa_logits():
('fp8_paged_mqa_logits', 'clean_logits'))
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
f'{tflops / t:3.0f} TFLOPS, {t * 1e6:3.0f} us, '
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, '
f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s | '
f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s')
print()
@@ -243,6 +243,5 @@ if __name__ == '__main__':
test_gemm_skip_head_mid()
if get_arch_major() == 9:
test_mqa_logits()
test_paged_mqa_logits()
test_mqa_logits()
test_paged_mqa_logits()