@@ -204,7 +204,7 @@ def test_paged_mqa_logits():
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
- schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, 132)
|
||||
+ schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms())
|
||||
logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True)
|
||||
|
||||
ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len)
|
||||
@@ -229,7 +229,7 @@ def test_paged_mqa_logits():
|
||||
('fp8_paged_mqa_logits', 'clean_logits'))
|
||||
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
|
||||
print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
|
||||
- f'{tflops / t:3.0f} TFLOPS, {t * 1e6:3.0f} us, '
|
||||
+ f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, '
|
||||
f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s | '
|
||||
f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s')
|
||||
print()
|
||||
@@ -243,6 +243,5 @@ if __name__ == '__main__':
|
||||
|
||||
test_gemm_skip_head_mid()
|
||||
|
||||
- if get_arch_major() == 9:
|
||||
- test_mqa_logits()
|
||||
|
||||
- test_paged_mqa_logits()
|
||||
+ test_mqa_logits()
|
||||
+ test_paged_mqa_logits()
|
||||
|
||||
Reference in New Issue
Block a user