[Model] Molmo2: Enable quantized weight mapping for vision backbone (#32385)
Signed-off-by: kimheesu <wlskaka4@gmail.com>
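
Pass quant_config through every vision-backbone module (ViTMLP,
ViTMultiHeadDotProductAttention, Molmo2VisionBlock, ImagePoolingAttention,
ImageProjectorMLP, Molmo2VisionBackbone) so their column/row-parallel linear
layers are built against the model's quantization config instead of silently
falling back to unquantized weights. Also drop the trailing dots from the
vision-backbone entries of hf_to_vllm_mapper so the substring rules match
bare module names as well as full parameter names.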
@@ -67,13 +67,13 @@ from vllm.multimodal.parse import (
     MultiModalDataParser,
 )
 from vllm.multimodal.processing import (
-    BaseDummyInputsBuilder,
     BaseMultiModalProcessor,
     BaseProcessingInfo,
     PromptReplacement,
     PromptUpdate,
     PromptUpdateDetails,
 )
+from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import round_down
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -295,6 +295,7 @@ class ViTMLP(nn.Module):
         hidden_dim: int,
         hidden_act: str,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.w1 = ColumnParallelLinear(
@@ -302,6 +303,7 @@ class ViTMLP(nn.Module):
             hidden_dim,
             bias=True,
+            quant_config=quant_config,
             prefix=f"{prefix}.w1",
         )
         # Activation function.
         self.act = get_act_fn(hidden_act)
@@ -310,6 +312,7 @@ class ViTMLP(nn.Module):
             dim,
             bias=True,
+            quant_config=quant_config,
             prefix=f"{prefix}.w2",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -364,12 +367,14 @@ class ViTMultiHeadDotProductAttention(nn.Module):
             self.total_num_kv_heads,
             bias=use_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.merged_qkv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=use_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.wo",
         )
         self.scale = self.head_dim**-0.5
         self.attn = MMEncoderAttention(
@@ -414,6 +419,7 @@ class Molmo2VisionBlock(nn.Module):
             hidden_dim=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
             prefix=f"{prefix}.feed_forward",
         )
         self.attention_norm = nn.LayerNorm(
             config.hidden_size,
@@ -548,6 +554,7 @@ class ImagePoolingAttention(nn.Module):
         use_bias: bool = True,
         use_pytorch_sdpa: bool = False,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

@@ -579,18 +586,21 @@ class ImagePoolingAttention(nn.Module):
             self.total_num_heads * self.head_dim,
             bias=use_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.q_proj",
         )
         self.merged_kv = MergedColumnParallelLinear(
             self.input_dim,
             [self.total_num_kv_heads * self.head_dim] * 2,
             bias=use_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.merged_kv",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=use_bias,
+            quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
         self.scale = self.head_dim**-0.5
         self.use_pytorch_sdpa = use_pytorch_sdpa
@@ -672,6 +682,7 @@ class ImageProjectorMLP(nn.Module):
         output_dim: int,
         hidden_act: str,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

@@ -680,6 +691,7 @@ class ImageProjectorMLP(nn.Module):
             [hidden_dim] * 2,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.merged_linear",
         )
         # Activation function.
         assert hidden_act == "silu"
@@ -691,6 +703,7 @@ class ImageProjectorMLP(nn.Module):
             output_dim,
             bias=False,
+            quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -745,6 +758,7 @@ class Molmo2VisionBackbone(nn.Module, SupportsQuant):
             head_dim=adapter_config.head_dim,
             use_pytorch_sdpa=adapter_config.pooling_attention_mask,
+            quant_config=quant_config,
             prefix=f"{prefix}.image_pooling_2d",
         )
         self.image_projector = ImageProjectorMLP(
             input_dim=adapter_config.hidden_size,
@@ -752,6 +766,7 @@ class Molmo2VisionBackbone(nn.Module, SupportsQuant):
             output_dim=adapter_config.text_hidden_size,
             hidden_act=adapter_config.hidden_act,
+            quant_config=quant_config,
             prefix=f"{prefix}.image_projector",
         )

     @property
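All of the constructor hunks above follow one pattern: each vision-backbone
submodule now accepts quant_config and forwards it, alongside a dotted module
prefix, to the parallel linear layers it creates. The sketch below illustrates
that pattern with hypothetical stand-ins (ToyQuantConfig, ToyLinear, ToyViTMLP
are not vLLM classes): the prefix identifies each layer, so a per-module
quantization decision (here, an ignore list) can be made at construction time.

# Illustrative sketch only; ToyQuantConfig/ToyLinear/ToyViTMLP are
# hypothetical stand-ins, not vLLM's real classes.
from dataclasses import dataclass, field


@dataclass
class ToyQuantConfig:
    # Module paths to keep unquantized, as checkpoint quant configs list them.
    ignore: set[str] = field(default_factory=set)

    def is_quantized(self, prefix: str) -> bool:
        return prefix not in self.ignore


class ToyLinear:
    def __init__(self, in_dim: int, out_dim: int,
                 quant_config: ToyQuantConfig | None = None,
                 prefix: str = ""):
        # No quant_config means a plain fp16/fp32 layer, which is exactly
        # why the diff threads it into every vision-backbone linear.
        self.quantized = bool(quant_config
                              and quant_config.is_quantized(prefix))


class ToyViTMLP:
    def __init__(self, dim: int, hidden_dim: int,
                 quant_config: ToyQuantConfig | None = None,
                 prefix: str = ""):
        # Same pattern as the hunks above: forward quant_config and a
        # dotted prefix to each sublayer.
        self.w1 = ToyLinear(dim, hidden_dim, quant_config, f"{prefix}.w1")
        self.w2 = ToyLinear(hidden_dim, dim, quant_config, f"{prefix}.w2")


cfg = ToyQuantConfig(ignore={"vision_backbone.blocks.0.feed_forward.w2"})
mlp = ToyViTMLP(1024, 4096, cfg, prefix="vision_backbone.blocks.0.feed_forward")
assert mlp.w1.quantized and not mlp.w2.quantized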
@@ -2438,13 +2453,13 @@ class Molmo2ForConditionalGeneration(
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             # vision backbone mapping
-            "image_pooling_2d.wq.": "image_pooling_2d.q_proj.",
-            "image_pooling_2d.wk.": "image_pooling_2d.k_proj.",
-            "image_pooling_2d.wv.": "image_pooling_2d.v_proj.",
-            "image_pooling_2d.wo.": "image_pooling_2d.o_proj.",
-            "image_projector.w1.": "image_projector.gate_proj.",
-            "image_projector.w3.": "image_projector.up_proj.",
-            "image_projector.w2.": "image_projector.down_proj.",
+            "image_pooling_2d.wq": "image_pooling_2d.q_proj",
+            "image_pooling_2d.wk": "image_pooling_2d.k_proj",
+            "image_pooling_2d.wv": "image_pooling_2d.v_proj",
+            "image_pooling_2d.wo": "image_pooling_2d.o_proj",
+            "image_projector.w1": "image_projector.gate_proj",
+            "image_projector.w3": "image_projector.up_proj",
+            "image_projector.w2": "image_projector.down_proj",
             # language backbone mapping
             "att_proj": "qkv_proj",
             "attn_out": "o_proj",
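For the mapper change: WeightsMapper applies orig_to_new_substr as plain
substring replacements over checkpoint weight names. A minimal sketch of that
behavior (map_name is a hypothetical helper, not vLLM's actual implementation)
shows why the trailing dots had to go: a dotted rule only fires when the name
continues past the module (".weight", ".qweight", ".scales", ...), so a bare
module path, the form quantized checkpoints use in their config's module
lists, was never remapped.

# Sketch of substring-based name mapping; illustrative, not vLLM's code.
def map_name(name: str, orig_to_new_substr: dict[str, str]) -> str:
    for orig, new in orig_to_new_substr.items():
        if orig in name:
            name = name.replace(orig, new)
    return name


dotted = {"image_pooling_2d.wq.": "image_pooling_2d.q_proj."}
dotless = {"image_pooling_2d.wq": "image_pooling_2d.q_proj"}

# Full parameter names (quantized or not) are remapped either way:
assert map_name("image_pooling_2d.wq.qweight", dotted) \
    == "image_pooling_2d.q_proj.qweight"
assert map_name("image_pooling_2d.wq.qweight", dotless) \
    == "image_pooling_2d.q_proj.qweight"

# A bare module path (e.g. an entry in a quantization config's
# ignored/target module list) only maps with the dotless rule:
assert map_name("image_pooling_2d.wq", dotted) == "image_pooling_2d.wq"
assert map_name("image_pooling_2d.wq", dotless) == "image_pooling_2d.q_proj"

The dotless substrings match more aggressively, but that is presumably
harmless here, since no Molmo2 weight name contains "image_pooling_2d.wq" or
the other stems as part of a longer identifier.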