[Model] Molmo2: Enable quantized weight mapping for vision backbone (#32385)

Signed-off-by: kimheesu <wlskaka4@gmail.com>
This commit is contained in:
Kim Hee Su
2026-01-17 18:33:05 +09:00
committed by GitHub
parent d3317bbba4
commit 1646fea672

View File

@@ -67,13 +67,13 @@ from vllm.multimodal.parse import (
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import round_down
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -295,6 +295,7 @@ class ViTMLP(nn.Module):
hidden_dim: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.w1 = ColumnParallelLinear(
@@ -302,6 +303,7 @@ class ViTMLP(nn.Module):
hidden_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.w1",
)
# Activation function.
self.act = get_act_fn(hidden_act)
@@ -310,6 +312,7 @@ class ViTMLP(nn.Module):
dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -364,12 +367,14 @@ class ViTMultiHeadDotProductAttention(nn.Module):
self.total_num_kv_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.merged_qkv",
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.wo",
)
self.scale = self.head_dim**-0.5
self.attn = MMEncoderAttention(
@@ -414,6 +419,7 @@ class Molmo2VisionBlock(nn.Module):
hidden_dim=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.attention_norm = nn.LayerNorm(
config.hidden_size,
@@ -548,6 +554,7 @@ class ImagePoolingAttention(nn.Module):
use_bias: bool = True,
use_pytorch_sdpa: bool = False,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -579,18 +586,21 @@ class ImagePoolingAttention(nn.Module):
self.total_num_heads * self.head_dim,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.merged_kv = MergedColumnParallelLinear(
self.input_dim,
[self.total_num_kv_heads * self.head_dim] * 2,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.merged_kv",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.scale = self.head_dim**-0.5
self.use_pytorch_sdpa = use_pytorch_sdpa
@@ -672,6 +682,7 @@ class ImageProjectorMLP(nn.Module):
output_dim: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -680,6 +691,7 @@ class ImageProjectorMLP(nn.Module):
[hidden_dim] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.merged_linear",
)
# Activation function.
assert hidden_act == "silu"
@@ -691,6 +703,7 @@ class ImageProjectorMLP(nn.Module):
output_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -745,6 +758,7 @@ class Molmo2VisionBackbone(nn.Module, SupportsQuant):
head_dim=adapter_config.head_dim,
use_pytorch_sdpa=adapter_config.pooling_attention_mask,
quant_config=quant_config,
prefix=f"{prefix}.image_pooling_2d",
)
self.image_projector = ImageProjectorMLP(
input_dim=adapter_config.hidden_size,
@@ -752,6 +766,7 @@ class Molmo2VisionBackbone(nn.Module, SupportsQuant):
output_dim=adapter_config.text_hidden_size,
hidden_act=adapter_config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.image_projector",
)
@property
@@ -2438,13 +2453,13 @@ class Molmo2ForConditionalGeneration(
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
"image_pooling_2d.wq.": "image_pooling_2d.q_proj.",
"image_pooling_2d.wk.": "image_pooling_2d.k_proj.",
"image_pooling_2d.wv.": "image_pooling_2d.v_proj.",
"image_pooling_2d.wo.": "image_pooling_2d.o_proj.",
"image_projector.w1.": "image_projector.gate_proj.",
"image_projector.w3.": "image_projector.up_proj.",
"image_projector.w2.": "image_projector.down_proj.",
"image_pooling_2d.wq": "image_pooling_2d.q_proj",
"image_pooling_2d.wk": "image_pooling_2d.k_proj",
"image_pooling_2d.wv": "image_pooling_2d.v_proj",
"image_pooling_2d.wo": "image_pooling_2d.o_proj",
"image_projector.w1": "image_projector.gate_proj",
"image_projector.w3": "image_projector.up_proj",
"image_projector.w2": "image_projector.down_proj",
# language backbone mapping
"att_proj": "qkv_proj",
"attn_out": "o_proj",