[NVIDIA] Add Cutlass MLA backend (#17625)

This commit is contained in:
Kaixi Hou
2025-06-04 12:40:26 +08:00
committed by GitHub
parent 8d646c2e53
commit 41aa578428
7 changed files with 111 additions and 3 deletions

View File

@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
q = torch.randn(bs, h_q, d)
# Amplify input values to ensure test coverage of edge cases where CUTLASS
# kernel errors occur with split_k settings.
q = torch.randn(bs, h_q, d) * 100
block_table = torch.randint(0,
bs * block_num, (bs, block_num),
dtype=torch.int32)