[Bugfix] Fix Qwen3-Next in_proj_ba weight sharding with TP > 1 (#36242)
Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
This commit is contained in:
@@ -145,6 +145,24 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused b/a input projection for Qwen3.5.

    The Qwen3.5 checkpoint stores in_proj_b and in_proj_a as separate
    weights, so the fused layer is declared with two equal output
    shards; each checkpoint tensor is loaded into its shard via
    stacked_params_mapping (shard_id 0 for b, shard_id 1 for a).
    """
    # Two equal shards, one per source weight: [b | a].
    shard_sizes = [num_v_heads, num_v_heads]
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=shard_sizes,
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -412,12 +412,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
prefix=f"{prefix}.in_proj_qkvz",
|
||||
)
|
||||
# ba_proj doesn't support blockwise fp8 quantization.
|
||||
# # in_proj_ba is defined as MergedColumnParallelLinear for
|
||||
# compatibility with Qwen3_5.
|
||||
self.in_proj_ba = MergedColumnParallelLinear(
|
||||
input_size=self.hidden_size,
|
||||
output_sizes=[self.num_v_heads] * 2,
|
||||
bias=False,
|
||||
# Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
|
||||
# layouts, so we use a factory method to create the projection.
|
||||
self.in_proj_ba = self.create_ba_proj(
|
||||
hidden_size=self.hidden_size,
|
||||
num_v_heads=self.num_v_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj_ba",
|
||||
)
|
||||
@@ -497,6 +496,28 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused b/a input projection for Qwen3-Next.

    Qwen3-Next keeps in_proj_ba as one fused checkpoint weight with an
    interleaved GQA layout ([b_g0, a_g0, b_g1, a_g1, ...], one b/a pair
    per key-head group). Declaring a single output shard makes
    column-parallel TP sharding split along group boundaries, which
    preserves that interleaving across ranks. Qwen3.5 overrides this
    method with two separate shards because its checkpoint stores
    in_proj_b and in_proj_a independently.
    """
    # One fused shard covering both the b and a halves.
    fused_width = 2 * num_v_heads
    return MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[fused_width],
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
|
||||
|
||||
def fix_query_key_value_ordering(
|
||||
self,
|
||||
mixed_qkvz: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user