[Misc] Add BNB quantization for MolmoForCausalLM (#11551)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
         return output
 
 
-class MolmoMLP(nn.Module):
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        # Note that the order is reversed compared to
+        # SiluAndMul.
+        return x * F.silu(gate)
+
+
+class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
     def __init__(self,
                  config: PretrainedConfig,
                  input_dim: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 proj_name: str = "gate_up_proj") -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Molmo's LLM proj weights are already merged into the disk, while
-        # image_projector proj is separate. If the same proj_name were used, it
-        # would create ambiguity and make it difficult to support BNB and LoRA.
-        self.proj_name = proj_name
-        setattr(
-            self, proj_name,
-            MergedColumnParallelLinear(
-                input_dim or self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                quant_config=quant_config,
-            ))
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        self.act_fn = SwiGLU()
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ImageProjectorMLP(nn.Module):
+    """Molmo's image_projector mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        self.merged_linear = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
         # Activation function.
         self.act_fn = SiluAndMul()
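The "reversed" comment above is the crux of the refactor: vLLM's SiluAndMul applies SiLU to the first half of a merged gate/up projection, while the new SwiGLU applies it to the second half, matching the (up, gate) order in which Molmo checkpoints store the merged weight. A minimal standalone check of the two conventions (plain PyTorch; not part of the commit):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    up, gate = torch.randn(4, 8), torch.randn(4, 8)

    # SiluAndMul convention: SiLU on the FIRST half, so the merged
    # tensor must be ordered (gate, up).
    g, u = torch.cat([gate, up], dim=-1).chunk(2, dim=-1)
    out_silu_and_mul = F.silu(g) * u

    # SwiGLU convention (this commit): SiLU on the SECOND half, so the
    # checkpoint-native (up, gate) order can be kept as-is.
    x, g2 = torch.cat([up, gate], dim=-1).chunk(2, dim=-1)
    out_swiglu = x * F.silu(g2)

    assert torch.allclose(out_silu_and_mul, out_swiglu)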
@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = getattr(self, self.proj_name)(x)
+        gate_up, _ = self.merged_linear(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
                                        prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config,
-                            quant_config=quant_config,
-                            proj_name="gate_up_proj")
+        self.mlp = LanuageModelMLP(config, quant_config=quant_config)
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
             vision_config,
             nlayers=len(self.vit_layers),
             quant_config=quant_config)
-        self.image_projector = MolmoMLP(
+        self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
-            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
-            if "gate_up_proj" in name:
-                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
-                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
-
             if name.endswith(".bias") and name not in params_dict:
                 continue
            if is_pp_missing_parameter(name, self):
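The deleted branch was the load-time counterpart of the old SiluAndMul convention: it swapped the checkpoint's (up, gate) halves into (gate, up) order on every load. With SwiGLU absorbing the ordering at activation time, the merged weight is consumed exactly as stored, and no per-tensor mutation happens during loading; plausibly this also simplifies the bitsandbytes path, where re-slicing quantized storage would be awkward. A weight-level sketch (not from the commit) showing the old and new paths agree:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    hidden, inter = 16, 32
    w = torch.randn(2 * inter, hidden)   # rows in checkpoint order: (up, gate)
    x = torch.randn(3, hidden)

    # Old path: swap the row halves at load time, SiLU on the first half.
    up_rows, gate_rows = w.chunk(2, dim=0)
    w_swapped = torch.cat([gate_rows, up_rows], dim=0)
    g, u = (x @ w_swapped.t()).chunk(2, dim=-1)
    y_old = F.silu(g) * u

    # New path: keep rows as stored, SiLU on the second half (SwiGLU).
    u2, g2 = (x @ w.t()).chunk(2, dim=-1)
    y_new = u2 * F.silu(g2)

    assert torch.allclose(y_old, y_new)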
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        "gate_proj": ("merged_linear", 0),
+        "up_proj": ("merged_linear", 1),
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
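The new class attribute tells vLLM's bitsandbytes loader that checkpoint shards named gate_proj and up_proj occupy slices 0 and 1 of the fused merged_linear parameter; only the image projector needs this, since the LLM's gate/up weights are already merged on disk. A usage sketch for in-flight BNB quantization, assuming a vLLM build that includes this commit and bitsandbytes installed (the model name is illustrative):

    from vllm import LLM

    # In-flight 4-bit bitsandbytes quantization of Molmo; both the
    # quantization and load_format flags were required around this time.
    llm = LLM(
        model="allenai/Molmo-7B-D-0924",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
        trust_remote_code=True,
    )
    out = llm.generate("Describe quantization in one sentence.")
    print(out[0].outputs[0].text)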