[Misc]Add BNB quantization for MolmoForCausalLM (#11551)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author:    Jee Jee Li
Date:      2024-12-28 02:48:24 +08:00
Committed: GitHub
Parent:    55509c2114
Commit:    0240402c46

2 changed files with 83 additions and 33 deletions
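
For context, a minimal usage sketch of what this commit enables. This is not part of the commit itself; the checkpoint name is the public Molmo release, and the flag combination follows vLLM's bitsandbytes loading convention at the time (exact flags may differ across vLLM versions):

```python
# Usage sketch, not part of this commit. Assumes the allenai/Molmo-7B-D-0924
# checkpoint and vLLM's convention of passing both quantization="bitsandbytes"
# and load_format="bitsandbytes" for in-flight BNB quantization.
from vllm import LLM

llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,       # Molmo ships custom HF modeling code
    quantization="bitsandbytes",  # quantize weights with bitsandbytes
    load_format="bitsandbytes",   # load weights through the BNB loader
)
print(llm.generate("Describe this image.")[0].outputs[0].text)
```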


@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
         return output
 
 
-class MolmoMLP(nn.Module):
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        # Note that the order is reversed compared to
+        # SiluAndMul.
+        return x * F.silu(gate)
+
+
+class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
     def __init__(self,
                  config: PretrainedConfig,
                  input_dim: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 proj_name: str = "gate_up_proj") -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Molmo's LLM proj weights are already merged into the disk, while
-        # image_projector proj is separate. If the same proj_name were used, it
-        # would create ambiguity and make it difficult to support BNB and LoRA.
-        self.proj_name = proj_name
-        setattr(
-            self, proj_name,
-            MergedColumnParallelLinear(
-                input_dim or self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                quant_config=quant_config,
-            ))
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        self.act_fn = SwiGLU()
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ImageProjectorMLP(nn.Module):
+    """Molmo's image_projector mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        self.merged_linear = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
         # Activation function.
         self.act_fn = SiluAndMul()
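
The reversed chunk order in SwiGLU is what lets the language model consume its merged gate_up_proj weights in their on-disk [up; gate] layout: instead of swapping the two halves of the weight tensor at load time (the chunk/cat swap removed from load_weights further down in this diff, which a quantized weight could not tolerate), the activation swaps which half receives the SiLU. A standalone sketch of the equivalence; the two helpers below are re-implementations for the check, mirroring the behavior of vLLM's SiluAndMul and this commit's SwiGLU, not imports from vLLM:

```python
import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Mirrors vLLM's SiluAndMul: SiLU on the first half, times the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # Mirrors this commit's SwiGLU: first half times SiLU of the second.
    x, gate = x.chunk(2, dim=-1)
    return x * F.silu(gate)

up, gate = torch.randn(4, 8), torch.randn(4, 8)
old = silu_and_mul(torch.cat([gate, up], dim=-1))  # halves swapped at load time
new = swiglu(torch.cat([up, gate], dim=-1))        # weights kept in disk order
assert torch.allclose(old, new)
```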
@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = getattr(self, self.proj_name)(x)
+        gate_up, _ = self.merged_linear(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
             prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config,
-                            quant_config=quant_config,
-                            proj_name="gate_up_proj")
+        self.mlp = LanuageModelMLP(config, quant_config=quant_config)
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
             vision_config,
             nlayers=len(self.vit_layers),
             quant_config=quant_config)
-        self.image_projector = MolmoMLP(
+        self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
-            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "gate_up_proj" in name:
-                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
-                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
-
             if name.endswith(".bias") and name not in params_dict:
                 continue
             if is_pp_missing_parameter(name, self):
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        "gate_proj": ("merged_linear", 0),
+        "up_proj": ("merged_linear", 1),
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
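
For reference, a sketch of how a stacked-params mapping of this shape is typically consumed by a BNB-style loader: each per-shard checkpoint weight name is rewritten to the fused parameter's name, and the shard index records where that shard's quantized weights land inside the fused tensor. The helper and the checkpoint name below are illustrative, not vLLM's actual loader code:

```python
# Illustrative sketch only; not vLLM's actual loader. Shows how a mapping
# like bitsandbytes_stacked_params_mapping translates per-shard checkpoint
# names into the fused vLLM parameter name plus a shard index.
from typing import Dict, Optional, Tuple

def map_checkpoint_name(
    name: str,
    stacked_mapping: Dict[str, Tuple[str, int]],
) -> Tuple[str, Optional[int]]:
    for shard_name, (fused_name, shard_id) in stacked_mapping.items():
        if shard_name in name:
            # Rewrite e.g. "...up_proj.weight" -> "...merged_linear.weight"
            return name.replace(shard_name, fused_name), shard_id
    return name, None  # unfused parameters pass through unchanged

mapping = {"gate_proj": ("merged_linear", 0), "up_proj": ("merged_linear", 1)}
# Hypothetical checkpoint weight name, for illustration:
print(map_checkpoint_name("image_projector.up_proj.weight", mapping))
# -> ('image_projector.merged_linear.weight', 1)
```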