[ROCm] Fix aiter persistent mode mla with q/o nhead<16 for kimi-k2.5 tp8 (#38615)
Signed-off-by: wufann <36477220+wufann@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -129,9 +129,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
from aiter import dtypes, get_mla_metadata_info_v1
|
||||
|
||||
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
# For num_attention_heads < 16 (e.g. kimi-k2.5 head=8 with TP8),
|
||||
# make sure get_mla_metadata_info_v1 / get_mla_metadata_v1 are consistent
|
||||
# with the actual tensor shape passed to mla_decode_fwd.
|
||||
self._num_attention_heads = max(16, self.num_heads)
|
||||
q_dtype = self.decode_attn_out_dtype
|
||||
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
|
||||
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||
|
||||
Reference in New Issue
Block a user