[Bugfix] Fix shape mismatch assertion error when loading Gemma3n model with BitsAndBytes quantization (#21808)
Signed-off-by: sydarb <areebsyed237@gmail.com>
This commit is contained in:
@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module):
|
|||||||
class Gemma3nLaurelBlock(nn.Module):
|
class Gemma3nLaurelBlock(nn.Module):
|
||||||
"""Learned Augmented Residual Layer"""
|
"""Learned Augmented Residual Layer"""
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
|
def __init__(
|
||||||
prefix: str):
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
laurel_rank: int,
|
||||||
|
rms_norm_eps: float,
|
||||||
|
*,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.linear_left = ColumnParallelLinear(
|
self.linear_left = ColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
laurel_rank,
|
laurel_rank,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.linear_left",
|
prefix=f"{prefix}.linear_left",
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
)
|
)
|
||||||
self.linear_right = RowParallelLinear(laurel_rank,
|
self.linear_right = RowParallelLinear(
|
||||||
|
laurel_rank,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.linear_right",
|
prefix=f"{prefix}.linear_right",
|
||||||
return_bias=False)
|
return_bias=False,
|
||||||
|
)
|
||||||
self.post_laurel_norm = RMSNorm(
|
self.post_laurel_norm = RMSNorm(
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
eps=rms_norm_eps,
|
eps=rms_norm_eps,
|
||||||
@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module):
|
|||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
laurel_rank=config.laurel_rank,
|
laurel_rank=config.laurel_rank,
|
||||||
rms_norm_eps=config.rms_norm_eps,
|
rms_norm_eps=config.rms_norm_eps,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.laurel",
|
prefix=f"{prefix}.laurel",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.hidden_size_per_layer_input,
|
config.hidden_size_per_layer_input,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.per_layer_input_gate",
|
prefix=f"{prefix}.per_layer_input_gate",
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
)
|
)
|
||||||
@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module):
|
|||||||
config.hidden_size_per_layer_input,
|
config.hidden_size_per_layer_input,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.per_layer_projection",
|
prefix=f"{prefix}.per_layer_projection",
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
)
|
)
|
||||||
@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.per_layer_model_projection",
|
prefix=f"{prefix}.per_layer_model_projection",
|
||||||
)
|
)
|
||||||
self.per_layer_projection_norm = RMSNorm(
|
self.per_layer_projection_norm = RMSNorm(
|
||||||
@@ -566,6 +581,7 @@ class Gemma3nTextModel(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.{idx-1}.altup_projections",
|
prefix=f"{prefix}.{idx-1}.altup_projections",
|
||||||
) for idx in range(1, self.config.altup_num_inputs)
|
) for idx in range(1, self.config.altup_num_inputs)
|
||||||
])
|
])
|
||||||
@@ -576,6 +592,7 @@ class Gemma3nTextModel(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
|
prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
|
||||||
) for idx in range(1, self.config.altup_num_inputs)
|
) for idx in range(1, self.config.altup_num_inputs)
|
||||||
])
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user