diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 6c1073b3a..8b764cd62 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -129,9 +129,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): from aiter import dtypes, get_mla_metadata_info_v1 - self._num_attention_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config - ) + # For num_attention_heads < 16 (e.g. kimi-k2.5 head=8 with TP8), + # make sure get_mla_metadata_info_v1 / get_mla_metadata_v1 are consistent + # with the actual tensor shape passed to mla_decode_fwd. + self._num_attention_heads = max(16, self.num_heads) q_dtype = self.decode_attn_out_dtype kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto") if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):