[ROCm] Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5/Linear) (#35850)
Signed-off-by: Li <chuali@amd.com>
This commit is contained in:
committed by
GitHub
parent
225d1090a0
commit
c188749bcd
@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
assert num_heads == 16 or num_heads == 128, (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
_valid_heads = num_heads in (4, 8) or (
|
||||
num_heads % 16 == 0 and 16 <= num_heads <= 128
|
||||
)
|
||||
assert _valid_heads, (
|
||||
f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
|
||||
f"in [16, 128].\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value."
|
||||
)
|
||||
self._needs_head_repeat = num_heads < 16
|
||||
self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
|
||||
if self._needs_head_repeat:
|
||||
q = q.repeat_interleave(self._head_repeat_factor, dim=1)
|
||||
kernel_num_heads = 16
|
||||
else:
|
||||
kernel_num_heads = self.num_heads
|
||||
|
||||
o = torch.zeros(
|
||||
B,
|
||||
self.num_heads,
|
||||
kernel_num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=attn_metadata.decode.attn_out_dtype,
|
||||
device=q.device,
|
||||
@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
kv_scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if self._needs_head_repeat:
|
||||
o = o[:, :: self._head_repeat_factor, :]
|
||||
|
||||
return o, None
|
||||
|
||||
Reference in New Issue
Block a user