[Bugfix] Fix Qwen3-Next in_proj_ba weight sharding with TP > 1 (#36242)

Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
This commit is contained in:
Ajay Anubolu
2026-03-09 19:16:26 -07:00
committed by GitHub
parent 179547d62c
commit 4e95ec111c
2 changed files with 45 additions and 6 deletions

View File

@@ -145,6 +145,24 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
prefix=prefix,
)
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused b/a input projection for Qwen3.5.

    The Qwen3.5 checkpoint stores in_proj_b and in_proj_a as two
    separate tensors; stacked_params_mapping loads them into this
    merged layer as shard_id 0 and 1 respectively, so two equal
    output shards are declared here.
    """
    # Two shards of num_v_heads each: one for in_proj_b, one for in_proj_a.
    shard_sizes = [num_v_heads, num_v_heads]
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=shard_sizes,
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -412,12 +412,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix=f"{prefix}.in_proj_qkvz",
)
# ba_proj doesn't support blockwise fp8 quantization.
# in_proj_ba is defined as MergedColumnParallelLinear for
# compatibility with Qwen3_5.
self.in_proj_ba = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.num_v_heads] * 2,
bias=False,
# Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
# layouts, so we use a factory method to create the projection.
self.in_proj_ba = self.create_ba_proj(
hidden_size=self.hidden_size,
num_v_heads=self.num_v_heads,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba",
)
@@ -497,6 +496,28 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix=prefix,
)
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused b/a input projection for Qwen3-Next.

    The Qwen3-Next checkpoint keeps in_proj_ba as a single fused
    weight in an interleaved GQA layout -- [b_g0, a_g0, b_g1, a_g1,
    ...], one b/a pair per key-head group. Declaring one output shard
    of size 2 * num_v_heads lets the column-parallel split preserve
    that interleaved structure across TP ranks. Qwen3.5 overrides this
    method with two separate shards because its checkpoint stores
    in_proj_b and in_proj_a independently.
    """
    # One fused shard. Do NOT use [num_v_heads] * 2 here: with TP > 1
    # that would shard b and a independently and shear the interleaved
    # per-group pairs apart.
    fused_size = num_v_heads * 2
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[fused_size],
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
def fix_query_key_value_ordering(
self,
mixed_qkvz: torch.Tensor,