[Bugfix][Model] Fix PixtralForConditionalGeneration LoRA (#36963)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Author: Jee Jee Li
Date: 2026-03-30 14:59:42 +08:00 (committed by GitHub)
Parent: 63babd17f1
Commit: ac30a8311e
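
In short, this change makes LoRA adapters usable with PixtralForConditionalGeneration: the vision encoder's attention and MLP move from plain nn.Linear onto vLLM's parallel linear layers, the class declares packed_modules_mapping for the fused projections, and load_weights learns to route per-projection checkpoint shards into them. A minimal usage sketch follows; the adapter path and rank are placeholders, not part of this commit:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

lora_path = "/path/to/pixtral-lora-adapter"  # hypothetical local adapter

llm = LLM(
    model="mistralai/Pixtral-12B-2409",
    tokenizer_mode="mistral",
    enable_lora=True,
    max_lora_rank=64,  # adjust to the adapter's rank
)

outputs = llm.generate(
    ["Describe the image."],
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("pixtral-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)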


@@ -8,7 +8,6 @@ from typing import Annotated, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
@@ -26,16 +25,18 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.activation import SiluAndMul, get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
@@ -293,6 +294,23 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
class PixtralForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_encoder.",
"model.multi_modal_projector.": "vision_language_adapter.",
},
orig_to_new_substr={
".linear_1.": ".w_in.",
".linear_2.": ".w_out.",
},
)
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
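
The two new class attributes are what checkpoint loading and LoRA key off: hf_to_vllm_mapper renames Hugging Face checkpoint prefixes and substrings onto this module tree, and packed_modules_mapping tells the LoRA machinery which per-projection adapter weights stack into each fused module. A standalone sketch of that name resolution, in plain Python mirroring the mappings above rather than importing vLLM:

ORIG_TO_NEW_PREFIX = {
    "model.language_model.": "language_model.model.",
    "model.vision_tower.": "vision_encoder.",
    "model.multi_modal_projector.": "vision_language_adapter.",
}
ORIG_TO_NEW_SUBSTR = {".linear_1.": ".w_in.", ".linear_2.": ".w_out."}
PACKED_MODULES = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def map_hf_name(name: str) -> str:
    """Apply prefix and substring renames the way the WeightsMapper entries do."""
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in ORIG_TO_NEW_SUBSTR.items():
        name = name.replace(old, new)
    return name

def packed_target(module_name: str) -> tuple[str, int] | None:
    """Resolve a per-projection LoRA target to (fused module, shard index)."""
    for packed, parts in PACKED_MODULES.items():
        if module_name in parts:
            return packed, parts.index(module_name)
    return None

assert map_hf_name("model.multi_modal_projector.linear_1.weight") == "vision_language_adapter.w_in.weight"
assert packed_target("k_proj") == ("qkv_proj", 1)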
@@ -325,7 +343,10 @@ class PixtralForConditionalGeneration(
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_encoder = VisionTransformer(
self.vision_args,
prefix=maybe_prefix(prefix, "vision_encoder"),
)
self.pre_mm_projector_norm = (
RMSNorm(self.vision_args.hidden_size, eps=1e-5)
if self.vision_args.add_pre_mm_projector_layer_norm
@@ -435,6 +456,15 @@ class PixtralForConditionalGeneration(
return self.language_model.get_eagle3_aux_hidden_state_layers()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
_vision_encoder_stacked_params = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith(("vision_encoder", "vision_tower"))
@@ -449,7 +479,6 @@ class PixtralForConditionalGeneration(
def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("pre_mm_projector_norm")
# Get references to parameters for direct loading
vision_encoder_dict = (
dict(self.vision_encoder.named_parameters())
if self.vision_encoder is not None
@@ -472,29 +501,41 @@ class PixtralForConditionalGeneration(
)
def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_encoder_weights((name, w)):
if _is_layer_none_or_staged(self.vision_encoder):
continue
# Load vision encoder weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = vision_encoder_dict.get(trimmed_name)
if param is not None:
with torch.no_grad():
default_weight_loader(param, w)
for (
param_name,
weight_name,
shard_id,
) in _vision_encoder_stacked_params:
if weight_name in trimmed_name:
trimmed_name = trimmed_name.replace(weight_name, param_name)
param = vision_encoder_dict[trimmed_name]
weight_loader = param.weight_loader
weight_loader(param, w, shard_id)
break
else:
param = vision_encoder_dict.get(trimmed_name)
if param is not None:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, w)
elif is_patch_merger((name, w)):
if _is_layer_none_or_staged(self.patch_merger):
continue
# Load vision patch merger weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = patch_merger_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)):
if _is_layer_none_or_staged(self.pre_mm_projector_norm):
continue
# Load vision pre_mm_projector_norm weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = pre_mm_projector_norm_dict[trimmed_name]
with torch.no_grad():
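
The loop above is the standard vLLM pattern for fused ("stacked") parameters: a checkpoint shard named ...q_proj/...k_proj/...v_proj (or gate_proj/up_proj) is renamed to the fused module and handed to that parameter's weight_loader along with a shard id, which copies it into the right slice. A minimal single-GPU sketch of what such a loader does; vLLM's real loaders additionally narrow to the local tensor-parallel rank:

import torch

hidden = 8
# Fused QKV weight: rows are [q | k | v], each hidden x hidden in this toy case.
qkv_weight = torch.empty(3 * hidden, hidden)

def qkv_weight_loader(param: torch.Tensor, loaded: torch.Tensor, shard_id: str) -> None:
    """Copy one projection's checkpoint shard into its block of the fused weight."""
    offset = {"q": 0, "k": 1, "v": 2}[shard_id] * hidden
    param[offset : offset + hidden].copy_(loaded)

for shard_id in ("q", "k", "v"):
    qkv_weight_loader(qkv_weight, torch.randn(hidden, hidden), shard_id)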
@@ -502,26 +543,23 @@ class PixtralForConditionalGeneration(
elif is_vision_lang_adapter_weights((name, w)):
if _is_layer_none_or_staged(self.vision_language_adapter):
continue
# Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = vision_lang_adapter_dict.get(trimmed_name)
if param is not None:
with torch.no_grad():
default_weight_loader(param, w)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, w)
else:
# LLM weights: yield them to be loaded
# by language_model.load_weights
# Strip "language_model." prefix if present (HF sharded format)
name = name.removeprefix("language_model.")
yield (name, w)
# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())
def get_mm_mapping(self) -> MultiModelKeys:
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="vision_language_adapter",
language_model="language_model.",
connector="vision_language_adapter.",
tower_model="vision_encoder",
)
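
get_mm_mapping now reports the language-model and connector prefixes with a trailing dot. These MultiModelKeys prefixes are matched against fully qualified module names, so the dot presumably keeps the match scoped to true submodules of language_model / vision_language_adapter rather than any name that merely shares the prefix. The distinction in plain Python (the "extra head" name is hypothetical, for illustration only):

def owned_by(module_name: str, prefix: str) -> bool:
    # Prefix match on a dotted path; a trailing "." enforces submodule boundaries.
    return module_name.startswith(prefix)

assert owned_by("language_model.model.layers.0.mlp.gate_up_proj", "language_model.")
assert owned_by("language_model_extra_head.proj", "language_model")        # false positive without the dot
assert not owned_by("language_model_extra_head.proj", "language_model.")   # rejected with the dot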
@@ -614,29 +652,78 @@ def apply_rotary_emb_vit(
class FeedForward(nn.Module):
def __init__(self, args: VisionEncoderArgs):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
disable_tp: bool = False,
) -> None:
super().__init__()
assert args.intermediate_size is not None
self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
disable_tp=disable_tp,
prefix=f"{prefix}.w13",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=disable_tp,
prefix=f"{prefix}.w2",
)
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class Attention(nn.Module):
def __init__(self, args: VisionEncoderArgs):
def __init__(
self,
args: VisionEncoderArgs,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
disable_tp: bool = False,
):
super().__init__()
self.args = args
assert not args.hidden_size % args.num_attention_heads
self.n_heads = args.num_attention_heads
self.head_dim = args.hidden_size // args.num_attention_heads
self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.qkv_proj = QKVParallelLinear(
hidden_size=args.hidden_size,
head_size=self.head_dim,
total_num_heads=args.num_attention_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wqkv",
disable_tp=disable_tp,
)
self.o_proj = RowParallelLinear(
input_size=args.hidden_size,
output_size=args.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wo",
disable_tp=disable_tp,
)
tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size()
self.n_heads = divide(args.num_attention_heads, tp_size)
def forward(
self,
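
The rewritten FeedForward fuses the old w1 (gate) and w3 (up) projections into one MergedColumnParallelLinear and applies SiluAndMul, so the math is unchanged from w2(silu(w1(x)) * w3(x)) while the module now goes through vLLM's LoRA- and quantization-aware linear stack. A quick plain-torch equivalence check, using a local silu_and_mul that mirrors SiluAndMul's split-in-half behavior:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter = 16, 32
x = torch.randn(4, hidden)
w1 = torch.randn(inter, hidden)   # gate projection
w3 = torch.randn(inter, hidden)   # up projection

def silu_and_mul(y: torch.Tensor) -> torch.Tensor:
    gate, up = y.chunk(2, dim=-1)
    return F.silu(gate) * up

ref = F.silu(x @ w1.T) * (x @ w3.T)          # old formulation
fused = torch.cat([w1, w3], dim=0)            # one [gate; up] matmul
out = silu_and_mul(x @ fused.T)               # new formulation
assert torch.allclose(ref, out, atol=1e-6)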
@@ -646,7 +733,8 @@ class Attention(nn.Module):
) -> torch.Tensor:
batch, patches, _ = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
qkv, _ = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.reshape(batch, patches, self.n_heads, self.head_dim)
k = k.reshape(batch, patches, self.n_heads, self.head_dim)
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
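
With QKVParallelLinear, each rank computes its local slice of q, k, and v in a single matmul; because Pixtral's ViT has no separate kv heads, the local output splits into three equal chunks, and n_heads here is the per-rank head count (total heads divided by the tensor-parallel size). A shape sketch for that path, assuming tp_size divides the head count:

import torch

total_heads, head_dim, tp_size = 16, 64, 4
n_heads = total_heads // tp_size            # per-rank heads, as divide() computes
batch, patches = 2, 9

# Local qkv_proj output: the q, k and v slices concatenated on the last dim.
qkv = torch.randn(batch, patches, 3 * n_heads * head_dim)
q, k, v = qkv.chunk(3, dim=-1)
q = q.reshape(batch, patches, n_heads, head_dim)
assert q.shape == (2, 9, 4, 64)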
@@ -663,14 +751,32 @@ class Attention(nn.Module):
out = out.transpose(1, 2)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)
out, _ = self.o_proj(out)
return out
class TransformerBlock(nn.Module):
def __init__(self, args: VisionEncoderArgs):
def __init__(
self,
args: VisionEncoderArgs,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
disable_tp: bool = False,
):
super().__init__()
self.attention = Attention(args)
self.feed_forward = FeedForward(args)
self.attention = Attention(
args,
quant_config=quant_config,
prefix=f"{prefix}.attention",
disable_tp=disable_tp,
)
self.feed_forward = FeedForward(
args.hidden_size,
args.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
disable_tp=disable_tp,
)
self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)
@@ -690,11 +796,24 @@ class TransformerBlock(nn.Module):
class Transformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
def __init__(
self,
args: VisionEncoderArgs,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
disable_tp: bool = False,
):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(args.num_hidden_layers):
self.layers.append(TransformerBlock(args))
for idx in range(args.num_hidden_layers):
self.layers.append(
TransformerBlock(
args,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}",
disable_tp=disable_tp,
)
)
def forward(
self,
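
Every submodule now receives a dotted prefix (e.g. vision_encoder.transformer.layers.0.attention.wqkv when the model's own prefix is empty), which is what quantization configs and the LoRA manager use to address individual layers. A small sketch of how those names compose, using a helper with the same behavior as vLLM's maybe_prefix:

def maybe_prefix(prefix: str, name: str) -> str:
    """Join a parent prefix and a child name, dropping the dot for an empty prefix."""
    return f"{prefix}.{name}" if prefix else name

vit = maybe_prefix("", "vision_encoder")       # top-level module
transformer = f"{vit}.transformer"
attn = f"{transformer}.layers.0.attention"     # per-layer block prefix
assert f"{attn}.wqkv" == "vision_encoder.transformer.layers.0.attention.wqkv"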
@@ -727,9 +846,15 @@ def position_meshgrid(
class VisionTransformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
def __init__(
self,
args: VisionEncoderArgs,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.args = args
disable_tp = is_vit_use_data_parallel()
self.patch_conv = Conv2dLayer(
in_channels=args.num_channels,
out_channels=args.hidden_size,
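
Throughout the rewrite, disable_tp is threaded down from is_vit_use_data_parallel(): when the vision tower runs data-parallel (replicated per rank) rather than tensor-parallel, its linear layers keep full-width weights and the per-rank head count falls back to the full head count. A sketch of that arithmetic, assuming the divide-by-TP behavior shown in Attention.__init__:

def local_heads(total_heads: int, tp_size: int, disable_tp: bool) -> int:
    """Per-rank attention heads: TP sharding is skipped when disable_tp is set."""
    effective_tp = 1 if disable_tp else tp_size
    assert total_heads % effective_tp == 0, "heads must divide evenly across TP ranks"
    return total_heads // effective_tp

assert local_heads(16, tp_size=4, disable_tp=False) == 4    # tensor-parallel ViT
assert local_heads(16, tp_size=4, disable_tp=True) == 16    # data-parallel ViT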
@@ -738,7 +863,12 @@ class VisionTransformer(nn.Module):
bias=False,
)
self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
self.transformer = Transformer(args)
self.transformer = Transformer(
args,
quant_config=quant_config,
prefix=f"{prefix}.transformer",
disable_tp=disable_tp,
)
head_dim = self.args.hidden_size // self.args.num_attention_heads
assert head_dim % 2 == 0, "ROPE requires even head_dim"
@@ -822,13 +952,16 @@ class VisionLanguageAdapter(nn.Module):
def __init__(self, args: VisionEncoderArgs, dim: int):
super().__init__()
assert isinstance(args, VisionEncoderArgs)
self.w_in = nn.Linear(
self.w_in = ReplicatedLinear(
args.hidden_size,
dim,
bias=args.adapter_bias,
return_bias=False,
)
self.gelu = nn.GELU()
self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)
self.w_out = ReplicatedLinear(
dim, dim, bias=args.adapter_bias, return_bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))
@@ -852,10 +985,8 @@ class PatchMerger(nn.Module):
self.spatial_merge_size = spatial_merge_size
self.mlp_input_dim = mlp_input_dim
self.merging_layer = nn.Linear(
mlp_input_dim,
vision_encoder_dim,
bias=use_mlp_bias,
self.merging_layer = ReplicatedLinear(
mlp_input_dim, vision_encoder_dim, bias=use_mlp_bias, return_bias=False
)
def forward(
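
VisionLanguageAdapter and PatchMerger swap nn.Linear for ReplicatedLinear so these layers also go through vLLM's quantization- and LoRA-aware linear stack without being sharded. vLLM's linear layers return an (output, bias) tuple by default; return_bias=False makes them return just the tensor, which is why call sites such as self.w_out(self.gelu(self.w_in(x))) stay untouched. An illustrative toy wrapper showing that convention (not vLLM's actual class):

import torch
import torch.nn as nn

class TupleReturningLinear(nn.Module):
    """Toy stand-in for the (output, bias) vs. output-only return convention."""

    def __init__(self, in_features: int, out_features: int, return_bias: bool = True):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features)
        self.return_bias = return_bias

    def forward(self, x: torch.Tensor):
        out = self.inner(x)
        return (out, self.inner.bias) if self.return_bias else out

layer = TupleReturningLinear(8, 8, return_bias=False)
y = layer(torch.randn(2, 8))   # a plain tensor, drop-in for nn.Linear call sites
assert isinstance(y, torch.Tensor)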