diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index e5967c122..78dda9ff4 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -182,8 +182,8 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): # ============================================================ mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj( hidden_states, - self.in_proj_qkvz.weight.shape[0], - self.in_proj_ba.weight.shape[0], + sum(self.in_proj_qkvz.output_sizes) // self.tp_size, + sum(self.in_proj_ba.output_sizes) // self.tp_size, self.prefix, ) qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 7aaded7ae..bf59c0c11 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -660,8 +660,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # ============================================================ projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj( hidden_states, - self.in_proj_qkvz.weight.shape[0], - self.in_proj_ba.weight.shape[0], + sum(self.in_proj_qkvz.output_sizes) // self.tp_size, + sum(self.in_proj_ba.output_sizes) // self.tp_size, self.prefix, ) query, key, value, z, b, a = self.fix_query_key_value_ordering(