[Bugfix] Fix shape mismatch assertion error when loading Gemma3n model with BitsAndBytes quantization (#21808)

Signed-off-by: sydarb <areebsyed237@gmail.com>
This commit is contained in:
Areeb Syed
2025-07-30 09:05:21 +05:30
committed by GitHub
parent b917da442b
commit fdde18229e

View File

@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module):
 class Gemma3nLaurelBlock(nn.Module):
     """Learned Augmented Residual Layer"""

-    def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
-                 prefix: str):
+    def __init__(
+        self,
+        hidden_size: int,
+        laurel_rank: int,
+        rms_norm_eps: float,
+        *,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str,
+    ) -> None:
         super().__init__()
         self.linear_left = ColumnParallelLinear(
             hidden_size,
             laurel_rank,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.linear_left",
             return_bias=False,
         )
-        self.linear_right = RowParallelLinear(laurel_rank,
-                                              hidden_size,
-                                              bias=False,
-                                              prefix=f"{prefix}.linear_right",
-                                              return_bias=False)
+        self.linear_right = RowParallelLinear(
+            laurel_rank,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_right",
+            return_bias=False,
+        )
         self.post_laurel_norm = RMSNorm(
             hidden_size=hidden_size,
             eps=rms_norm_eps,
@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             laurel_rank=config.laurel_rank,
             rms_norm_eps=config.rms_norm_eps,
+            quant_config=quant_config,
             prefix=f"{prefix}.laurel",
         )
@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size,
             config.hidden_size_per_layer_input,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_input_gate",
             return_bias=False,
         )
@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module):
             config.hidden_size_per_layer_input,
             config.hidden_size,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_projection",
             return_bias=False,
         )
@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module):
             bias=False,
             gather_output=True,
             return_bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.per_layer_model_projection",
         )
         self.per_layer_projection_norm = RMSNorm(
@@ -566,6 +581,7 @@ class Gemma3nTextModel(nn.Module):
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])
@@ -576,6 +592,7 @@ class Gemma3nTextModel(nn.Module):
                 bias=False,
                 gather_output=True,
                 return_bias=False,
+                quant_config=quant_config,
                 prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
             ) for idx in range(1, self.config.altup_num_inputs)
         ])