diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 0d891b8c9..e179638a8 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -8,7 +8,6 @@ from typing import Annotated, Literal
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
 from mistral_common.protocol.instruct.messages import UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
@@ -26,16 +25,18 @@ from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import MultiModalDataDict
-from vllm.model_executor.layers.activation import get_act_and_mul_fn
+from vllm.model_executor.layers.activation import SiluAndMul, get_act_and_mul_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
 from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
@@ -293,6 +294,23 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
 class PixtralForConditionalGeneration(
     nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP
 ):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_encoder.",
+            "model.multi_modal_projector.": "vision_language_adapter.",
+        },
+        orig_to_new_substr={
+            ".linear_1.": ".w_in.",
+            ".linear_2.": ".w_out.",
+        },
+    )
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
@@ -325,7 +343,10 @@ class PixtralForConditionalGeneration(
         )
 
         with self._mark_tower_model(vllm_config, "image"):
-            self.vision_encoder = VisionTransformer(self.vision_args)
+            self.vision_encoder = VisionTransformer(
+                self.vision_args,
+                prefix=maybe_prefix(prefix, "vision_encoder"),
+            )
             self.pre_mm_projector_norm = (
                 RMSNorm(self.vision_args.hidden_size, eps=1e-5)
                 if self.vision_args.add_pre_mm_projector_layer_norm
@@ -435,6 +456,15 @@ class PixtralForConditionalGeneration(
         return self.language_model.get_eagle3_aux_hidden_state_layers()
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        _vision_encoder_stacked_params = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
         def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith(("vision_encoder", "vision_tower"))
 
@@ -449,7 +479,6 @@ class PixtralForConditionalGeneration(
         def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith("pre_mm_projector_norm")
 
-        # Get references to parameters for direct loading
         vision_encoder_dict = (
             dict(self.vision_encoder.named_parameters())
             if self.vision_encoder is not None
@@ -472,29 +501,41 @@ class PixtralForConditionalGeneration(
         )
 
         def llm_weights_generator():
-            # Single pass over weights
             for name, w in weights:
                 if is_vision_encoder_weights((name, w)):
                     if _is_layer_none_or_staged(self.vision_encoder):
                         continue
-                    # Load vision encoder weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
-                    param = vision_encoder_dict.get(trimmed_name)
-                    if param is not None:
-                        with torch.no_grad():
-                            default_weight_loader(param, w)
+                    for (
+                        param_name,
+                        weight_name,
+                        shard_id,
+                    ) in _vision_encoder_stacked_params:
+                        if weight_name in trimmed_name:
+                            trimmed_name = trimmed_name.replace(weight_name, param_name)
+                            param = vision_encoder_dict[trimmed_name]
+                            weight_loader = param.weight_loader
+                            weight_loader(param, w, shard_id)
+                            break
+                    else:
+                        param = vision_encoder_dict.get(trimmed_name)
+                        if param is not None:
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, w)
                 elif is_patch_merger((name, w)):
                     if _is_layer_none_or_staged(self.patch_merger):
                         continue
-                    # Load vision patch merger weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = patch_merger_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, w)
                 elif is_pre_mm_projector_norm((name, w)):
                     if _is_layer_none_or_staged(self.pre_mm_projector_norm):
                         continue
-                    # Load vision pre_mm_projector_norm weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = pre_mm_projector_norm_dict[trimmed_name]
                     with torch.no_grad():
@@ -502,26 +543,23 @@ class PixtralForConditionalGeneration(
                 elif is_vision_lang_adapter_weights((name, w)):
                     if _is_layer_none_or_staged(self.vision_language_adapter):
                         continue
-                    # Load vision-language adapter weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_lang_adapter_dict.get(trimmed_name)
                     if param is not None:
-                        with torch.no_grad():
-                            default_weight_loader(param, w)
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, w)
                 else:
-                    # LLM weights: yield them to be loaded
-                    # by language_model.load_weights
-                    # Strip "language_model." prefix if present (HF sharded format)
                     name = name.removeprefix("language_model.")
                     yield (name, w)
 
-        # Now we call the language model load with the generator
         self.language_model.load_weights(llm_weights_generator())
 
     def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
-            language_model="language_model",
-            connector="vision_language_adapter",
+            language_model="language_model.",
+            connector="vision_language_adapter.",
             tower_model="vision_encoder",
         )
 
@@ -614,29 +652,78 @@ def apply_rotary_emb_vit(
 
 class FeedForward(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: QuantizationConfig | None = None,
+        bias: bool = False,
+        prefix: str = "",
+        reduce_results: bool = True,
+        disable_tp: bool = False,
+    ) -> None:
         super().__init__()
-        assert args.intermediate_size is not None
-        self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
-        self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
-        self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[intermediate_size] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            disable_tp=disable_tp,
+            prefix=f"{prefix}.w13",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            disable_tp=disable_tp,
+            prefix=f"{prefix}.w2",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class Attention(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
+    def __init__(
+        self,
+        args: VisionEncoderArgs,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        disable_tp: bool = False,
+    ):
         super().__init__()
         self.args = args
 
         assert not args.hidden_size % args.num_attention_heads
-        self.n_heads = args.num_attention_heads
         self.head_dim = args.hidden_size // args.num_attention_heads
 
-        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
-        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
-        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
-        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=args.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=args.num_attention_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.wqkv",
+            disable_tp=disable_tp,
+        )
+        self.o_proj = RowParallelLinear(
+            input_size=args.hidden_size,
+            output_size=args.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.wo",
+            disable_tp=disable_tp,
+        )
+
+        tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size()
+        self.n_heads = divide(args.num_attention_heads, tp_size)
 
     def forward(
         self,
@@ -646,7 +733,8 @@ class Attention(nn.Module):
     ) -> torch.Tensor:
         batch, patches, _ = x.shape
 
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
+        qkv, _ = self.qkv_proj(x)
+        q, k, v = qkv.chunk(3, dim=-1)
         q = q.reshape(batch, patches, self.n_heads, self.head_dim)
         k = k.reshape(batch, patches, self.n_heads, self.head_dim)
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
@@ -663,14 +751,32 @@ class Attention(nn.Module):
         out = out.transpose(1, 2)
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
-        return self.wo(out)
+        out, _ = self.o_proj(out)
+        return out
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
+    def __init__(
+        self,
+        args: VisionEncoderArgs,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        disable_tp: bool = False,
+    ):
         super().__init__()
-        self.attention = Attention(args)
-        self.feed_forward = FeedForward(args)
+        self.attention = Attention(
+            args,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attention",
+            disable_tp=disable_tp,
+        )
+        self.feed_forward = FeedForward(
+            args.hidden_size,
+            args.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
+            disable_tp=disable_tp,
+        )
         self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
         self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)
 
     def forward(
@@ -690,11 +796,24 @@ class TransformerBlock(nn.Module):
 
 
 class Transformer(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
+    def __init__(
+        self,
+        args: VisionEncoderArgs,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        disable_tp: bool = False,
+    ):
         super().__init__()
         self.layers = torch.nn.ModuleList()
-        for _ in range(args.num_hidden_layers):
-            self.layers.append(TransformerBlock(args))
+        for idx in range(args.num_hidden_layers):
+            self.layers.append(
+                TransformerBlock(
+                    args,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{idx}",
+                    disable_tp=disable_tp,
+                )
+            )
 
     def forward(
         self,
@@ -727,9 +846,15 @@ def position_meshgrid(
 
 
 class VisionTransformer(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
+    def __init__(
+        self,
+        args: VisionEncoderArgs,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.args = args
+        disable_tp = is_vit_use_data_parallel()
         self.patch_conv = Conv2dLayer(
             in_channels=args.num_channels,
             out_channels=args.hidden_size,
@@ -738,7 +863,12 @@ class VisionTransformer(nn.Module):
             bias=False,
         )
         self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
-        self.transformer = Transformer(args)
+        self.transformer = Transformer(
+            args,
+            quant_config=quant_config,
+            prefix=f"{prefix}.transformer",
+            disable_tp=disable_tp,
+        )
 
         head_dim = self.args.hidden_size // self.args.num_attention_heads
         assert head_dim % 2 == 0, "ROPE requires even head_dim"
@@ -822,13 +952,16 @@ class VisionLanguageAdapter(nn.Module):
     def __init__(self, args: VisionEncoderArgs, dim: int):
         super().__init__()
         assert isinstance(args, VisionEncoderArgs)
-        self.w_in = nn.Linear(
+        self.w_in = ReplicatedLinear(
             args.hidden_size,
             dim,
             bias=args.adapter_bias,
+            return_bias=False,
         )
         self.gelu = nn.GELU()
-        self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)
+        self.w_out = ReplicatedLinear(
+            dim, dim, bias=args.adapter_bias, return_bias=False
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))
@@ -852,10 +985,8 @@ class PatchMerger(nn.Module):
         self.spatial_merge_size = spatial_merge_size
         self.mlp_input_dim = mlp_input_dim
 
-        self.merging_layer = nn.Linear(
-            mlp_input_dim,
-            vision_encoder_dim,
-            bias=use_mlp_bias,
+        self.merging_layer = ReplicatedLinear(
+            mlp_input_dim, vision_encoder_dim, bias=use_mlp_bias, return_bias=False
         )
 
     def forward(
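
The `hf_to_vllm_mapper` added above renames HF-format checkpoint keys before they reach `load_weights`. A minimal sketch of that renaming with plain string operations (the example keys are illustrative; vLLM's `WeightsMapper` applies its substring and prefix maps internally, and for these disjoint patterns the application order does not change the result):

```python
prefixes = {
    "model.language_model.": "language_model.model.",
    "model.vision_tower.": "vision_encoder.",
    "model.multi_modal_projector.": "vision_language_adapter.",
}
substrs = {".linear_1.": ".w_in.", ".linear_2.": ".w_out."}

def map_name(name: str) -> str:
    # Apply at most one prefix rule, then any substring rules.
    for old, new in prefixes.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in substrs.items():
        name = name.replace(old, new)
    return name

print(map_name("model.multi_modal_projector.linear_1.weight"))
# -> vision_language_adapter.w_in.weight
print(map_name("model.vision_tower.ln_pre.weight"))
# -> vision_encoder.ln_pre.weight
```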
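`_vision_encoder_stacked_params` drives the new shard-aware loading loop: checkpoints store separate `q_proj`/`k_proj`/`v_proj` (and `gate_proj`/`up_proj`) tensors, while the module now owns fused `qkv_proj`/`gate_up_proj` parameters whose `weight_loader` copies each shard into its slice. A self-contained sketch of just the name routing, including the `for`/`else` fallthrough for unfused weights (the weight names below are illustrative):

```python
STACKED = [
    # (fused_param_substr, checkpoint_substr, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def route(name: str) -> tuple[str, object]:
    """Map a checkpoint weight name to (fused_param_name, shard_id)."""
    for param_substr, ckpt_substr, shard_id in STACKED:
        if ckpt_substr in name:
            # Fused case: the loader is later called as loader(param, w, shard_id).
            return name.replace(ckpt_substr, param_substr), shard_id
    # The for/else in the patch: no fused pattern matched, load 1:1.
    return name, None

print(route("transformer.layers.0.attention.q_proj.weight"))
# -> ('transformer.layers.0.attention.qkv_proj.weight', 'q')
print(route("transformer.layers.0.feed_forward.up_proj.weight"))
# -> ('transformer.layers.0.feed_forward.gate_up_proj.weight', 1)
print(route("ln_pre.weight"))
# -> ('ln_pre.weight', None)
```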
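The rewritten `FeedForward` computes the same SwiGLU as the deleted `w1`/`w2`/`w3` version: `MergedColumnParallelLinear` concatenates the gate and up projections into one matmul, and `SiluAndMul` splits the fused output in half and returns `silu(gate) * up`. A plain-torch, single-rank sketch of the equivalence (dimensions are illustrative; tensor parallelism and quantization are ignored):

```python
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
torch.manual_seed(0)
w1 = torch.nn.Linear(hidden, inter, bias=False)  # gate_proj
w3 = torch.nn.Linear(hidden, inter, bias=False)  # up_proj
w2 = torch.nn.Linear(inter, hidden, bias=False)  # down_proj

def old_ffn(x):  # the deleted code path
    return w2(F.silu(w1(x)) * w3(x))

def new_ffn(x):  # fused path: one matmul, then split
    gate_up = torch.cat([w1.weight, w3.weight], dim=0)  # ~ MergedColumnParallelLinear
    gu = x @ gate_up.T
    gate, up = gu.chunk(2, dim=-1)  # ~ SiluAndMul's internal split
    return w2(F.silu(gate) * up)

x = torch.randn(2, hidden)
assert torch.allclose(old_ffn(x), new_ffn(x), atol=1e-6)
```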
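In the new `Attention`, the only tensor-parallel bookkeeping visible in `forward` is the per-rank head count: `QKVParallelLinear` emits `3 * (num_heads / tp_size) * head_dim` features per rank, so `chunk(3, dim=-1)` recovers the local q, k, v. A quick arithmetic check (the numbers are illustrative):

```python
total_heads, head_dim, tp_size = 16, 64, 4

assert total_heads % tp_size == 0   # what vllm's divide() enforces
n_heads = total_heads // tp_size    # per-rank heads, as in the patch

qkv_width = 3 * n_heads * head_dim  # per-rank qkv_proj output width
q, k, v = (qkv_width // 3,) * 3     # chunk(3, dim=-1) splits evenly
assert q == n_heads * head_dim      # each chunk reshapes to (n_heads, head_dim)
print(n_heads, qkv_width, q)        # 4 768 256
```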