Add 320 dimension size support to MLA (#36161)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
@@ -23,7 +23,7 @@ CACHE_LAYOUTS = ["NHD", "HND"]
|
||||
KV_SCALE_TYPES = ["tensor", "attn_head"]
|
||||
|
||||
# Parameters for MLA tests.
|
||||
KV_LORA_RANKS = [512]
|
||||
KV_LORA_RANKS = [256, 512]
|
||||
QK_ROPE_HEAD_DIMS = [64]
|
||||
NUM_TOKENS_MLA = [42]
|
||||
BLOCK_SIZES_MLA = [16]
|
||||
@@ -627,6 +627,8 @@ def test_concat_and_cache_ds_mla(
|
||||
pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
|
||||
if dtype.itemsize != 2:
|
||||
pytest.skip("ds_mla only supports 16-bit input")
|
||||
if kv_lora_rank != 512:
|
||||
pytest.skip("fp8_ds_mla requires kv_lora_rank == 512")
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
@@ -663,7 +665,8 @@ def test_concat_and_cache_ds_mla(
|
||||
ref_cache_32bit = ref_cache_slice.view(torch.float32)
|
||||
|
||||
kv_c_data = kv_c[i]
|
||||
for tile_idx in range(4):
|
||||
num_tiles = kv_lora_rank // 128
|
||||
for tile_idx in range(num_tiles):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = (tile_idx + 1) * 128
|
||||
tile_data[:] = kv_c_data[tile_start:tile_end]
|
||||
|
||||
Reference in New Issue
Block a user