[Model] Molmo2: Enable quantized weight mapping for vision backbone (#32385)

Signed-off-by: kimheesu <wlskaka4@gmail.com>
Author: Kim Hee Su
Date: 2026-01-17 18:33:05 +09:00
Committed by: GitHub
Parent: d3317bbba4
Commit: 1646fea672


@@ -67,13 +67,13 @@ from vllm.multimodal.parse import (
     MultiModalDataParser,
 )
 from vllm.multimodal.processing import (
-    BaseDummyInputsBuilder,
     BaseMultiModalProcessor,
     BaseProcessingInfo,
     PromptReplacement,
     PromptUpdate,
     PromptUpdateDetails,
 )
+from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import round_down
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -295,6 +295,7 @@ class ViTMLP(nn.Module):
hidden_dim: int, hidden_dim: int,
hidden_act: str, hidden_act: str,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.w1 = ColumnParallelLinear( self.w1 = ColumnParallelLinear(
@@ -302,6 +303,7 @@ class ViTMLP(nn.Module):
             hidden_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.w1",
         )
         # Activation function.
         self.act = get_act_fn(hidden_act)
@@ -310,6 +312,7 @@ class ViTMLP(nn.Module):
             dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.w2",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
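Why thread `prefix` into these constructors: vLLM's parallel linear layers hand it to `QuantizationConfig.get_quant_method(layer, prefix=...)`, so a quantization backend can make a per-module decision (quantize, skip, or pick a scheme) keyed on the module's dotted name. Before this change the vision-backbone linears all reported an empty prefix, so per-layer rules could not target them. Below is a minimal, self-contained sketch of the pattern with toy classes; only the `get_quant_method` hook mirrors vLLM's real interface, everything else is illustrative.

class ToyQuantConfig:
    def __init__(self, ignore: list[str]) -> None:
        self.ignore = ignore

    def get_quant_method(self, layer: object, prefix: str) -> str:
        # A per-layer decision needs a real dotted name; with prefix=""
        # (the pre-commit behavior) every vision layer looked identical.
        return "unquantized" if prefix in self.ignore else "int8"

class ToyLinear:
    def __init__(self, quant_config: ToyQuantConfig, prefix: str = "") -> None:
        self.quant_method = quant_config.get_quant_method(self, prefix)

class ToyMLP:
    def __init__(self, quant_config: ToyQuantConfig, prefix: str = "") -> None:
        # Mirrors ViTMLP above: each child extends the parent's prefix.
        self.w1 = ToyLinear(quant_config, prefix=f"{prefix}.w1")
        self.w2 = ToyLinear(quant_config, prefix=f"{prefix}.w2")

cfg = ToyQuantConfig(ignore=["visual.mlp.w2"])
mlp = ToyMLP(cfg, prefix="visual.mlp")
assert mlp.w1.quant_method == "int8"
assert mlp.w2.quant_method == "unquantized"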
@@ -364,12 +367,14 @@ class ViTMultiHeadDotProductAttention(nn.Module):
             self.total_num_kv_heads,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.merged_qkv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.wo",
         )
         self.scale = self.head_dim**-0.5
         self.attn = MMEncoderAttention(
@@ -414,6 +419,7 @@ class Molmo2VisionBlock(nn.Module):
             hidden_dim=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
         self.attention_norm = nn.LayerNorm(
             config.hidden_size,
@@ -548,6 +554,7 @@ class ImagePoolingAttention(nn.Module):
         use_bias: bool = True,
         use_pytorch_sdpa: bool = False,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -579,18 +586,21 @@ class ImagePoolingAttention(nn.Module):
             self.total_num_heads * self.head_dim,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.q_proj",
         )
         self.merged_kv = MergedColumnParallelLinear(
             self.input_dim,
             [self.total_num_kv_heads * self.head_dim] * 2,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.merged_kv",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.scale = self.head_dim**-0.5
         self.use_pytorch_sdpa = use_pytorch_sdpa
@@ -672,6 +682,7 @@ class ImageProjectorMLP(nn.Module):
         output_dim: int,
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -680,6 +691,7 @@ class ImageProjectorMLP(nn.Module):
             [hidden_dim] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.merged_linear",
         )
         # Activation function.
         assert hidden_act == "silu"
@@ -691,6 +703,7 @@ class ImageProjectorMLP(nn.Module):
             output_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
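The projector follows the gated-SiLU pattern implied by the hunks above: `merged_linear` fuses the gate (`w1`) and up (`w3`) projections into one matmul producing `2 * hidden_dim` features, the activation gates the first half with the second (vLLM's `SiluAndMul`), and `down_proj` (`w2`) maps back to the text hidden size. A functional sketch with assumed shapes:

import torch
import torch.nn.functional as F

def projector_mlp(x: torch.Tensor,
                  merged_w: torch.Tensor,   # [2 * hidden_dim, input_dim]
                  down_w: torch.Tensor,     # [output_dim, hidden_dim]
                  ) -> torch.Tensor:
    gate_up = x @ merged_w.T                # fused w1/w3 projection
    gate, up = gate_up.chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down_w.T   # silu-gated, then w2/down_proj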
@@ -745,6 +758,7 @@ class Molmo2VisionBackbone(nn.Module, SupportsQuant):
             head_dim=adapter_config.head_dim,
             use_pytorch_sdpa=adapter_config.pooling_attention_mask,
             quant_config=quant_config,
+            prefix=f"{prefix}.image_pooling_2d",
         )
         self.image_projector = ImageProjectorMLP(
             input_dim=adapter_config.hidden_size,
@@ -752,6 +766,7 @@ class Molmo2VisionBackbone(nn.Module, SupportsQuant):
             output_dim=adapter_config.text_hidden_size,
             hidden_act=adapter_config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.image_projector",
         )

     @property
@@ -2438,13 +2453,13 @@ class Molmo2ForConditionalGeneration(
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             # vision backbone mapping
-            "image_pooling_2d.wq.": "image_pooling_2d.q_proj.",
-            "image_pooling_2d.wk.": "image_pooling_2d.k_proj.",
-            "image_pooling_2d.wv.": "image_pooling_2d.v_proj.",
-            "image_pooling_2d.wo.": "image_pooling_2d.o_proj.",
-            "image_projector.w1.": "image_projector.gate_proj.",
-            "image_projector.w3.": "image_projector.up_proj.",
-            "image_projector.w2.": "image_projector.down_proj.",
+            "image_pooling_2d.wq": "image_pooling_2d.q_proj",
+            "image_pooling_2d.wk": "image_pooling_2d.k_proj",
+            "image_pooling_2d.wv": "image_pooling_2d.v_proj",
+            "image_pooling_2d.wo": "image_pooling_2d.o_proj",
+            "image_projector.w1": "image_projector.gate_proj",
+            "image_projector.w3": "image_projector.up_proj",
+            "image_projector.w2": "image_projector.down_proj",
             # language backbone mapping
             "att_proj": "qkv_proj",
             "attn_out": "o_proj",