From 5c3bae1a6a73aad8c0883a097079448b506fbfcc Mon Sep 17 00:00:00 2001 From: ant-yy Date: Wed, 15 Oct 2025 16:44:04 +0800 Subject: [PATCH] [Fix] Remove divisibility requirement between num_kv_heads and tp_size in bailing_moe (#26876) Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index a7f3ebed6..1549c6534 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -86,13 +86,12 @@ class BailingAttention(nn.Module): tp_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tp_size == 0 - assert self.total_kv_heads % tp_size == 0 assert self.total_num_heads >= self.total_kv_heads self.num_heads = self.total_num_heads // tp_size self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads - self.num_kv_heads = self.total_kv_heads // tp_size + self.num_kv_heads = max(1, self.total_kv_heads // tp_size) self.kv_size_per_rank = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 self.use_qk_norm = getattr(config, "use_qk_norm", False)