diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index dde1fb3eb..6dbdd7dcb 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -221,11 +221,17 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             kv_sharing_target_layer_name,
             **mla_args,
         )
-        assert num_heads == 16 or num_heads == 128, (
-            f"Aiter MLA only supports 16 or 128 number of heads.\n"
+        _valid_heads = num_heads in (4, 8) or (
+            num_heads % 16 == 0 and 16 <= num_heads <= 128
+        )
+        assert _valid_heads, (
+            f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
+            f"in [16, 128].\n"
             f"Provided {num_heads} number of heads.\n"
             "Try adjusting tensor_parallel_size value."
         )
+        self._needs_head_repeat = num_heads < 16
+        self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
             raise NotImplementedError(
@@ -267,9 +273,16 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
+
+        if self._needs_head_repeat:
+            q = q.repeat_interleave(self._head_repeat_factor, dim=1)
+            kernel_num_heads = 16
+        else:
+            kernel_num_heads = self.num_heads
+
         o = torch.zeros(
             B,
-            self.num_heads,
+            kernel_num_heads,
             self.kv_lora_rank,
             dtype=attn_metadata.decode.attn_out_dtype,
             device=q.device,
         )
@@ -291,4 +304,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             kv_scale=layer._k_scale,
         )
 
+        if self._needs_head_repeat:
+            o = o[:, :: self._head_repeat_factor, :]
+
         return o, None
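
Note (not part of the patch): a minimal, self-contained sketch of the head-repeat round trip the decode path above relies on. The shapes, the `factor` value, and the fake per-head kernel are illustrative assumptions, not vLLM code; the only real operations shown are the `repeat_interleave` expansion and the strided slice used to drop the duplicate heads afterwards.

import torch

# Illustrative assumption: tensor parallelism leaves 4 query heads on this rank.
num_heads = 4
factor = 16 // num_heads
B, head_dim, kv_lora_rank = 2, 576, 512

q = torch.randn(B, num_heads, head_dim)

# repeat_interleave along the head dim: original head i occupies the
# contiguous slots [i * factor, (i + 1) * factor) of the 16-head tensor.
q_rep = q.repeat_interleave(factor, dim=1)
assert q_rep.shape == (B, 16, head_dim)

# Stand-in for the decode kernel: each head's output depends only on that
# head's query, so every repeated copy produces an identical row.
o_rep = q_rep.sum(dim=-1, keepdim=True).expand(B, 16, kv_lora_rank)

# Striding by `factor` keeps the first copy from each repeated group,
# restoring the original head count.
o = o_rep[:, ::factor, :]
assert o.shape == (B, num_heads, kv_lora_rank)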