[ROCm] Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5/Linear) (#35850)

Signed-off-by: Li <chuali@amd.com>
Author: Chuan (Richard) Li
Date:   2026-03-06 12:24:03 -08:00
Commit: c188749bcd
Parent: 225d1090a0

@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             kv_sharing_target_layer_name,
             **mla_args,
         )
-        assert num_heads == 16 or num_heads == 128, (
-            f"Aiter MLA only supports 16 or 128 number of heads.\n"
+        _valid_heads = num_heads in (4, 8) or (
+            num_heads % 16 == 0 and 16 <= num_heads <= 128
+        )
+        assert _valid_heads, (
+            f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
+            f"in [16, 128].\n"
             f"Provided {num_heads} number of heads.\n"
             "Try adjusting tensor_parallel_size value."
         )
+        self._needs_head_repeat = num_heads < 16
+        self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
             raise NotImplementedError(
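
The relaxed check admits 4 or 8 query heads per rank (the nhead<16 case that Kimi K2.5/Linear hits at TP=8) alongside the previously supported multiples of 16 up to 128. A minimal standalone sketch of the condition, with illustrative head counts that are not part of this diff:

def is_supported_num_heads(num_heads: int) -> bool:
    # 4 or 8 heads per rank (small-nhead MLA under high TP), or any
    # multiple of 16 in [16, 128], as before this change.
    return num_heads in (4, 8) or (num_heads % 16 == 0 and 16 <= num_heads <= 128)

for n in (4, 8, 12, 16, 64, 128):
    print(n, is_supported_num_heads(n))  # only 12 is rejected in this list
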
@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
+        if self._needs_head_repeat:
+            q = q.repeat_interleave(self._head_repeat_factor, dim=1)
+            kernel_num_heads = 16
+        else:
+            kernel_num_heads = self.num_heads
         o = torch.zeros(
             B,
-            self.num_heads,
+            kernel_num_heads,
             self.kv_lora_rank,
             dtype=attn_metadata.decode.attn_out_dtype,
             device=q.device,
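
When fewer than 16 heads are present, the query is repeated head-wise so the Aiter decode kernel sees 16 heads, and the output buffer is allocated with kernel_num_heads rather than self.num_heads. A short sketch of the padding step, assuming illustrative shapes (B, num_heads, head_dim are not taken from the diff):

import torch

B, num_heads, head_dim = 2, 8, 576       # illustrative values only
repeat = 16 // num_heads                 # plays the role of self._head_repeat_factor

q = torch.randn(B, num_heads, head_dim)
q_padded = q.repeat_interleave(repeat, dim=1)
print(q_padded.shape)                    # torch.Size([2, 16, 576])
# Heads 2*h and 2*h+1 of q_padded are identical copies of head h of q.
assert torch.equal(q_padded[:, 0], q_padded[:, 1])
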
@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             kv_scale=layer._k_scale,
         )
+        if self._needs_head_repeat:
+            o = o[:, :: self._head_repeat_factor, :]
         return o, None
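
Because the repeated heads carry identical queries over the same KV cache, the kernel's per-head outputs coincide within each repeat group, and the strided slice o[:, :: self._head_repeat_factor, :] recovers one row per original head. A round-trip sketch under the same illustrative shapes (kv_lora_rank chosen arbitrarily):

import torch

B, num_heads, kv_lora_rank = 2, 8, 512   # illustrative values only
repeat = 16 // num_heads

o_true = torch.randn(B, num_heads, kv_lora_rank)    # stand-in for the true per-head output
o_kernel = o_true.repeat_interleave(repeat, dim=1)  # what the kernel returns for the padded q
o = o_kernel[:, ::repeat, :]                        # the slice added in this commit
assert torch.equal(o, o_true)
print(o.shape)                                      # torch.Size([2, 8, 512])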