[NVIDIA] Add Cutlass MLA backend (#17625)
This commit is contained in:
@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
 
-    q = torch.randn(bs, h_q, d)
+    # Amplify input values to ensure test coverage of edge cases where CUTLASS
+    # kernel errors occur with split_k settings.
+    q = torch.randn(bs, h_q, d) * 100
     block_table = torch.randint(0,
                                 bs * block_num, (bs, block_num),
                                 dtype=torch.int32)
Reference in New Issue
Block a user