From 1b117cb0ac51b84dd7aed364e6e802dea5147ca6 Mon Sep 17 00:00:00 2001 From: wufann <36477220+wufann@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:54:00 +0800 Subject: [PATCH] [ROCm] Fix aiter persistent mode mla with q/o nhead<16 for kimi-k2.5 tp8 (#38615) Signed-off-by: wufann <36477220+wufann@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 6c1073b3a..8b764cd62 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -129,9 +129,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): from aiter import dtypes, get_mla_metadata_info_v1 - self._num_attention_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config - ) + # For num_attention_heads < 16 (e.g. kimi-k2.5 head=8 with TP8), + # make sure get_mla_metadata_info_v1 / get_mla_metadata_v1 are consistent + # with the actual tensor shape passed to mla_decode_fwd. + self._num_attention_heads = max(16, self.num_heads) q_dtype = self.decode_attn_out_dtype kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto") if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):