diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index 85f455101..2a5b49282 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -145,6 +145,24 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
             prefix=prefix,
         )
 
+    def create_ba_proj(
+        self,
+        hidden_size: int,
+        num_v_heads: int,
+        quant_config: QuantizationConfig | None,
+        prefix: str,
+    ) -> MergedColumnParallelLinear:
+        # Qwen3.5 has separate in_proj_b and in_proj_a weights in the
+        # checkpoint, which are loaded into the fused in_proj_ba parameter
+        # via stacked_params_mapping with shard_id 0 and 1 respectively.
+        return MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[num_v_heads] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 4c4ff0ccf..343f58be9 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -412,12 +412,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             prefix=f"{prefix}.in_proj_qkvz",
         )
         # ba_proj doesn't support blockwise fp8 quantization.
-        # in_proj_ba is defined as MergedColumnParallelLinear for
-        # compatibility with Qwen3_5.
-        self.in_proj_ba = MergedColumnParallelLinear(
-            input_size=self.hidden_size,
-            output_sizes=[self.num_v_heads] * 2,
-            bias=False,
+        # Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
+        # layouts, so we use a factory method to create the projection.
+        self.in_proj_ba = self.create_ba_proj(
+            hidden_size=self.hidden_size,
+            num_v_heads=self.num_v_heads,
             quant_config=quant_config,
             prefix=f"{prefix}.in_proj_ba",
         )
@@ -497,6 +496,28 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             prefix=prefix,
         )
 
+    def create_ba_proj(
+        self,
+        hidden_size: int,
+        num_v_heads: int,
+        quant_config: QuantizationConfig | None,
+        prefix: str,
+    ) -> MergedColumnParallelLinear:
+        # Qwen3-Next stores in_proj_ba as a single fused weight with an
+        # interleaved GQA layout: [b_g0, a_g0, b_g1, a_g1, ...] where
+        # each group corresponds to a key-head group. We must use a single
+        # output shard so that ColumnParallel sharding preserves this
+        # interleaved structure across TP ranks.
+        # Qwen3.5 overrides this to use [num_v_heads, num_v_heads] since
+        # its checkpoint has separate in_proj_b and in_proj_a weights.
+        return MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[num_v_heads * 2],
+            bias=False,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+
     def fix_query_key_value_ordering(
         self,
         mixed_qkvz: torch.Tensor,