Files
vllm/vllm/model_executor/models/pixtral.py
Russell Bryant e489ad7a21 [Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

1126 lines
41 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
# xformers is an optional dependency. USE_XFORMERS_OPS records whether the
# import succeeded so that code further down can choose an attention path
# (or raise) without attempting the import again.
try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False
def get_max_pixtral_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens a single image can use.

    The value is the square of the number of patches per side, computed
    from ``max_image_size`` and ``image_patch_size`` of the tokenizer's
    multimodal encoder configuration.

    Args:
        ctx: Context of the loaded model; its ``model_config`` supplies the
            tokenizer name and tokenizer mode.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode)

    encoder_config = tokenizer.instruct.mm_encoder.mm_config
    patches_per_side = (encoder_config.max_image_size //
                        encoder_config.image_patch_size)
    return patches_per_side**2
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for Pixtral.

    A black 256x256 image is encoded once to learn how many tokens one
    image occupies; the sequence is then filled with that many image tokens
    per image, followed by zeros up to ``seq_len``.

    Args:
        ctx: Context of the loaded model.
        seq_len: Total length of the dummy token sequence.
        mm_counts: Not read by this function; the image count is taken
            from the multimodal config's ``limit_per_prompt`` instead.

    Returns:
        A ``DummyData`` holding the sequence data, the dummy images and
        their placeholder ranges.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode)
    image_token_id = (
        tokenizer.mistral.instruct_tokenizer.mm_encoder.special_ids.img)

    num_images = ctx.get_mm_config().limit_per_prompt.get("image", 1)

    # dummy size
    side = 256
    dummy_image = Image.new("RGB", (side, side), color=0)
    encoded = tokenizer.instruct.mm_encoder(ImageChunk(image=dummy_image))
    tokens_per_image = len(encoded.tokens)
    total_image_tokens = tokens_per_image * num_images

    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, total_image_tokens),
        (0, seq_len - total_image_tokens),
    )
    mm_data = {"image": [dummy_image] * num_images}
    placeholder_ranges = consecutive_placeholder_ranges(
        num_items=num_images, item_size=tokens_per_image)
    return DummyData(seq_data, mm_data, {"image": placeholder_ranges})
def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalKwargs:
    """Encode image inputs into the keyword arguments the model consumes.

    Args:
        ctx: Context of the loaded model.
        data: A single image or a list of images; each one is wrapped in an
            ``ImageChunk`` and passed through the tokenizer's multimodal
            encoder.

    Returns:
        MultiModalKwargs with two entries: ``images``, a list holding one
        float16 tensor per input image, and ``image_tokens``, a single 1-D
        tensor with the token ids of all images concatenated in input
        order.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)

    if not isinstance(data, list):
        data = [data]

    pixel_tensors = []
    flat_token_ids = []
    for item in data:
        encoded = tokenizer.instruct.mm_encoder(ImageChunk(image=item))
        pixel_tensors.append(
            torch.from_numpy(encoded.image).to(dtype=torch.float16))
        flat_token_ids.extend(encoded.tokens)

    return MultiModalKwargs({
        "images": pixel_tensors,
        "image_tokens": torch.tensor(flat_token_ids),
    })
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Attach image placeholder ranges to already-tokenized Pixtral inputs.

    Args:
        ctx: Context of the loaded model.
        inputs: Decoder-only inputs. They are returned untouched when they
            carry no ``"image"`` entry under ``multi_modal_data``.

    Returns:
        Token inputs with the same prompt, token ids and multimodal data,
        plus ``multi_modal_placeholders`` listing one ``PlaceholderRange``
        per image.

    Raises:
        ValueError: If image data is present but the prompt token ids do
            not contain the image token id.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    prompt_token_ids = inputs["prompt_token_ids"]
    prompt = inputs.get("prompt")

    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)
    special_ids = tokenizer.mistral.instruct_tokenizer.mm_encoder.special_ids
    image_token_id = special_ids.img
    image_break_id = special_ids.img_break
    image_end_id = special_ids.img_end

    if image_token_id not in prompt_token_ids:
        raise ValueError(
            f"You've passed {inputs=} without {image_token_id=}."
            " Make sure to process your input via mistral_common's"
            " tokenizer or pass a chat completion request."
            " For more info, see: "
            "https://github.com/vllm-project/vllm/issues/8411.")

    # Get precise tracking of placeholder positions. A range opens at the
    # first image/break token and closes (inclusively) at the end token.
    placeholder_ranges = []
    curr_offset = -1
    curr_length = 0
    for i, token_id in enumerate(prompt_token_ids):
        if token_id in (image_token_id, image_break_id):
            if curr_offset < 0:
                curr_offset = i
            curr_length += 1
        elif token_id == image_end_id:
            curr_length += 1
            placeholder_ranges.append(
                PlaceholderRange(offset=curr_offset, length=curr_length))
            curr_offset = -1
            curr_length = 0

    return token_inputs(prompt=prompt,
                        prompt_token_ids=prompt_token_ids,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": placeholder_ranges})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """Pixtral model: a vision encoder and adapter feeding a language model.

    Images are encoded by ``vision_encoder``, projected by
    ``vision_language_adapter`` and merged into the language model's input
    embeddings at the image placeholder positions.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        self.config = model_config.hf_config
        self.multimodal_config = model_config.multimodal_config

        # Keep only the vision config entries that VisionEncoderArgs knows.
        known_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {}
        for key, value in self.config.vision_config.to_dict().items():
            if key in known_fields:
                vision_args[key] = value

        missing_special_ids = ("image_break_token_id" not in vision_args
                               or "image_end_token_id" not in vision_args)
        if missing_special_ids:
            raise ValueError(
                "'image_break_token_id' and 'image_end_token_id' not found "
                "in the vision_encoder arguments. Please download the latest "
                "version of 'params.json' from the model repository.")

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        text_config = self.config.text_config
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a default one."""
        language_model = self.language_model
        if not hasattr(language_model, "sampler"):
            return get_sampler()
        return language_model.sampler

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Compute per-image embeddings from the multimodal keyword inputs.

        Returns ``None`` when no images were passed. Otherwise returns the
        embeddings for all image-related tokens, either as one tensor with
        a leading dimension of 1 or as a tuple with one tensor per image.
        """
        image_input, image_tokens = self._parse_and_validate_image_input(
            **kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # NOTE: We patch the outputs of the vision encoder with embeddings
        # from `[IMG_BREAK]` and `[IMG_END]` tokens.
        image_embeds = self.language_model.get_input_embeddings(image_tokens)
        is_image_token = image_tokens == self.vision_args.image_token_id
        image_embeds[is_image_token] = vision_embeddings

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the indices of `[IMG_END]` token.
        is_image_end = image_tokens == self.vision_args.image_end_token_id
        split_indices = torch.where(is_image_end)[0] + 1
        if len(split_indices) <= 1:
            # Do not split, return as tensor of shape [1, fs, hs]
            return image_embeds.unsqueeze(0)

        # If the last split index is the last index in image_tokens, we
        # ignore it to avoid empty split tensor
        if split_indices[-1] == len(image_tokens):
            split_indices = split_indices[:-1]

        return image_embeds.tensor_split(split_indices.cpu())

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in image embeddings when given."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        placeholder_token_ids = [
            self.vision_args.image_token_id,
            self.vision_args.image_break_token_id,
            self.vision_args.image_end_token_id,
        ]
        return merge_multimodal_embeddings(input_ids, inputs_embeds,
                                           multimodal_embeddings,
                                           placeholder_token_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None,
        image_tokens: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
        """Flatten batched image inputs and image tokens.

        Returns ``(None, None)`` when ``images`` is ``None``; otherwise a
        flat list of image tensors and a 1-D tensor of image token ids.
        """
        if images is None:
            return None, None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, W, H = images.shape
            stacked = images.reshape(N * B, C, W, H)
            images = [stacked[idx] for idx in range(stacked.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flattened = []
            for per_request in images:
                if isinstance(per_request, torch.Tensor):
                    per_request = [
                        per_request[idx]
                        for idx in range(per_request.size(0))
                    ]
                flattened.extend(per_request)
            images = flattened

        if isinstance(image_tokens, torch.Tensor):
            # image_tokens are batched
            image_tokens = image_tokens.flatten()
        elif isinstance(image_tokens, list):
            # image_tokens are of different lengths thus passed as a list
            image_tokens = torch.cat(image_tokens)

        assert image_tokens.dim() == 1
        return images, image_tokens

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        """Encode the images and project them through the adapter."""
        encoded = self.vision_encoder(image_input)
        return self.vision_language_adapter(encoded)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights in a single pass over ``weights``.

        Weights whose names start with ``vision_encoder`` or
        ``vision_language_adapter`` are copied straight into the matching
        parameters; all others are forwarded to the language model.
        """
        # Get references to parameters for direct loading
        vision_encoder_params = dict(self.vision_encoder.named_parameters())
        adapter_params = dict(self.vision_language_adapter.named_parameters())

        def load_directly(params, name, tensor):
            # Drop the leading "<module>." component to get the local name.
            local_name = name.partition(".")[2]
            param = params[local_name]
            with torch.no_grad():
                default_weight_loader(param, tensor)

        def language_model_weights():
            # Single pass over weights
            for name, tensor in weights:
                if name.startswith("vision_encoder"):
                    load_directly(vision_encoder_params, name, tensor)
                elif name.startswith("vision_language_adapter"):
                    load_directly(adapter_params, name, tensor)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, tensor)

        # Now we call the language model load with the generator
        self.language_model.load_weights(language_model_weights())
# Vision encoder
@dataclass
class VisionEncoderArgs:
    """Hyperparameters and special token ids for the Pixtral vision encoder."""

    # Transformer dimensions
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    # Special token ids used to lay out an image in the prompt
    image_token_id: int
    image_break_token_id: int
    image_end_token_id: int
    # Whether the vision-language adapter's linear layers carry a bias
    adapter_bias: bool = True
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
x: torch.Tensor) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim = x.ndim
assert ndim > 1
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
freqs_cis.shape,
(x.shape[1], x.shape[-1]),
)
shape = [
d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
]
return freqs_cis.view(*shape)
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Build the 2D rotary table as a complex tensor of shape
    (height, width, dim // 2), to be indexed by (height, width) position
    tuples. The even-indexed frequency bases are driven by the row index
    and the odd-indexed ones by the column index.
    """
    # (dim / 2) frequency bases
    exponents = torch.arange(0, dim, 2).float() / dim
    freqs = 1.0 / (theta**exponents)

    rows = torch.arange(height, device=freqs.device)
    cols = torch.arange(width, device=freqs.device)
    row_angles = torch.outer(rows, freqs[::2]).float()
    col_angles = torch.outer(cols, freqs[1::2]).float()

    # Tile row angles across columns and column angles across rows, then
    # join them along the feature dimension.
    angles = torch.cat(
        [
            row_angles[:, None, :].expand(-1, width, -1),
            col_angles[None, :, :].expand(height, -1, -1),
        ],
        dim=-1,
    )
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the complex factors in ``freqs_cis``.

    The last dimension of each input is paired up into complex numbers,
    multiplied by ``freqs_cis`` (which must be ``complex64``), and
    converted back to real values in the dtype of the respective input.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair adjacent features on the last dim: (..., d) -> (..., d/2)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    assert freqs_cis.dtype == torch.complex64

    rotation = _reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class FeedForward(nn.Module):
    """Bias-free gated MLP: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        hidden_size = args.hidden_size
        intermediate_size = args.intermediate_size
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, run via xformers."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert args.hidden_size % args.num_attention_heads == 0

        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        hidden_size = args.hidden_size
        self.wq = nn.Linear(hidden_size, hidden_size, bias=False)
        self.wk = nn.Linear(hidden_size, hidden_size, bias=False)
        self.wv = nn.Linear(hidden_size, hidden_size, bias=False)
        self.wo = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape (batch, patches, hidden)."""
        batch, patches, _ = x.shape
        per_head_shape = (batch, patches, self.n_heads, self.head_dim)

        q = self.wq(x).reshape(per_head_shape)
        k = self.wk(x).reshape(per_head_shape)
        v = self.wv(x).reshape(per_head_shape)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        attended = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        # Merge the heads back into a single feature dimension.
        merged = attended.reshape(batch, patches,
                                  self.n_heads * self.head_dim)
        return self.wo(merged)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        normed_input = self.attention_norm(x)
        attn_out = self.attention.forward(normed_input,
                                          mask=mask,
                                          freqs_cis=freqs_cis)
        after_attention = x + attn_out

        mlp_out = self.feed_forward.forward(self.ffn_norm(after_attention))
        return after_attention + mlp_out
class Transformer(nn.Module):
    """Stack of ``args.num_hidden_layers`` transformer blocks."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        blocks = [
            TransformerBlock(args) for _ in range(args.num_hidden_layers)
        ]
        self.layers = torch.nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Pass ``x`` through every block in order."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor:
    """Return (row, col) grid coordinates for every patch of every image.

    For each tensor, the last two dimensions are taken as the patch grid's
    height and width. The coordinates are listed row by row and the
    per-image results are concatenated into one tensor with 2 columns.
    """
    per_image_positions = []
    for patch_embeds in patch_embeds_list:
        grid_rows, grid_cols = torch.meshgrid(
            torch.arange(patch_embeds.shape[-2]),
            torch.arange(patch_embeds.shape[-1]),
            indexing="ij",
        )
        coords = torch.stack((grid_rows, grid_cols), dim=-1)
        per_image_positions.append(coords.reshape(-1, 2))
    return torch.cat(per_image_positions)
class VisionTransformer(nn.Module):
    """Pixtral vision encoder: patch convolution, RMSNorm and transformer.

    All images of a call are packed into one sequence; a block-diagonal
    attention mask keeps patches of different images from attending to
    each other.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        patch_size = args.patch_size
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Rotary table, built lazily on first access of `freqs_cis`.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        args = self.args
        return args.image_size // args.patch_size

    @property
    def device(self) -> torch.types.Device:
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self) -> torch.dtype:
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Cached 2D rotary table, kept on the module's current device."""
        table = self._freqs_cis
        if table is None:
            side = self.max_patches_per_side
            table = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=side,
                width=side,
                theta=self.args.rope_theta,
            )

        target_device = self.device
        if table.device != target_device:
            table = table.to(device=target_device)

        self._freqs_cis = table
        return table

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = []
        for image in images:
            batched = image.unsqueeze(0).to(self.dtype)
            patch_embeds_list.append(self.patch_conv(batched))

        # flatten to a single sequence
        sequences = [
            embeds.flatten(2).transpose(1, 2) for embeds in patch_embeds_list
        ]
        patch_embeds = self.ln_pre(torch.cat(sequences, dim=1))

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if not USE_XFORMERS_OPS:
            raise ImportError("Xformers is required for Pixtral inference "
                              "with the Mistral format")
        patches_per_image = [
            embeds.shape[-2] * embeds.shape[-1]
            for embeds in patch_embeds_list
        ]
        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            patches_per_image)
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
class VisionLanguageAdapter(nn.Module):
    """Two-layer GELU MLP mapping vision features to width ``dim``."""

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        use_bias = args.adapter_bias
        self.w_in = nn.Linear(args.hidden_size, dim, bias=use_bias)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)
#### HF Transformers version of Pixtral ####
# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are placed
# instead of the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.
def get_pixtral_hf_patch_grid_length(*, image_size: int,
                                     patch_size: int) -> int:
    """Return how many whole patches fit along one side of the image."""
    # Since interpolation is applied, the image size need not be divisible
    # assert image_size % patch_size == 0
    whole_patches, _remainder = divmod(image_size, patch_size)
    return whole_patches
def get_pixtral_hf_image_feature_size(
    *,
    image_size: int,
    patch_size: int,
) -> int:
    """Return the token count of a square image, break tokens included."""
    patches_per_side = get_pixtral_hf_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )

    # Consider the image_break_token
    tokens_per_row = patches_per_side + 1
    return tokens_per_row * patches_per_side
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
    """Return the token count of an image at the configured ``image_size``.

    Delegates to ``get_pixtral_hf_image_feature_size`` so the formula
    (which accounts for the image_break_token) lives in one place.

    Args:
        hf_config: Vision config supplying ``image_size`` and
            ``patch_size``.
    """
    return get_pixtral_hf_image_feature_size(
        image_size=hf_config.image_size,
        patch_size=hf_config.patch_size,
    )
def dummy_image_for_pixtral_hf(
    hf_config: PixtralVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create black dummy image data.

    The image is ``hf_config.image_size`` pixels on each side unless a
    width and/or height override is given.

    Returns:
        ``{"image": image}`` when ``num_images`` is 1, otherwise
        ``{"image": [image] * num_images}``.
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    image = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": image}
    return {"image": [image] * num_images}
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
def get_pixtral_hf_image_feature_grid_size(
    hf_config: PixtralVisionConfig,
    *,
    image_width: int,
    image_height: int,
) -> tuple[int, int]:
    """Return the ``(ncols, nrows)`` patch grid for an image.

    An image larger than ``hf_config.image_size`` in either dimension is
    first scaled down, keeping its aspect ratio, so that it fits.
    """
    max_side = hf_config.image_size
    patch_side = hf_config.patch_size

    # Ratio of the most oversized dimension; > 1 means the image must shrink.
    ratio = max(image_width / max_side, image_height / max_side)
    if ratio > 1:
        image_width = int(math.ceil(image_width / ratio))
        image_height = int(math.ceil(image_height / ratio))

    image_hw = (image_height, image_width)
    patch_hw = (patch_side, patch_side)
    nrows, ncols = _get_pixtral_hf_num_image_tokens(
        image_hw,
        patch_hw,
    )  # type: ignore

    return ncols, nrows
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    """Size and token-count queries for the HF-format Pixtral encoder."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # NOTE(review): image_width and image_height are not used; the
        # result depends only on the configured image and patch sizes.
        vision_config = self.vision_config
        return get_pixtral_hf_image_feature_size(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )

    def get_max_image_tokens(self) -> int:
        vision_config = self.vision_config
        return get_max_pixtral_hf_image_tokens(vision_config)

    def get_image_size(self) -> int:
        vision_config = self.vision_config
        return vision_config.image_size

    def get_patch_size(self) -> int:
        vision_config = self.vision_config
        return vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        vision_config = self.vision_config
        return get_pixtral_hf_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
class PixtralHFMLP(nn.Module):
    """Gated MLP built from a merged gate/up projection and a down projection."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        # Gate and up projections share one merged linear layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        merged, _ = self.gate_up_proj(x)
        gated = self.act_and_mul(merged)
        projected, _ = self.down_proj(gated)
        return projected
class PixtralHFAttention(nn.Module):
    """Self-attention for the HF-format Pixtral vision encoder.

    Uses a merged QKV projection and HF's rotary embedding. Attention runs
    through xformers when it was importable, and through PyTorch's
    ``scaled_dot_product_attention`` otherwise.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Number of heads handled by this tensor-parallel rank.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over the packed patch sequence.

        Args:
            hidden_states: Input of shape (batch, patches, hidden).
            attention_mask: Passed to xformers as ``attn_bias`` or to SDPA
                as ``attn_mask``, depending on the active path.
            position_embeddings: ``(cos, sin)`` pair for the rotary
                embedding.

        Returns:
            ``(attn_output, None)``.
        """
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        # `reshape` rather than `view`: after the transpose on the SDPA path
        # the tensor may be non-contiguous, in which case `view` raises.
        # `reshape` still returns a view whenever one is possible.
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None
class PixtralHFTransformerBlock(nn.Module):
    """Pre-norm block: attention then MLP, each added back as a residual."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        hidden_size = config.hidden_size
        self.attention_norm = RMSNorm(hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        normed_states = self.attention_norm(hidden_states)
        attn_out, _ = self.attention.forward(
            normed_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings)
        after_attention = hidden_states + attn_out

        mlp_out = self.feed_forward.forward(self.ffn_norm(after_attention))
        return after_attention + mlp_out
class PixtralHFTransformer(nn.Module):
    """Stack of ``PixtralHFTransformerBlock`` layers.

    ``num_hidden_layers_override`` replaces the layer count from the
    config when given.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)

        blocks = []
        for layer_idx in range(num_hidden_layers):
            blocks.append(
                PixtralHFTransformerBlock(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}"))
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor:
        """Run all layers.

        Returns the last layer's output, or, when
        ``return_all_hidden_states`` is set, a list with the output of
        every layer in order.
        """
        if not return_all_hidden_states:
            for layer in self.layers:
                x = layer(x, attention_mask, position_embeddings)
            return x

        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        per_layer_outputs = []
        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            per_layer_outputs.append(x)
        return per_layer_outputs
class PixtralHFVisionModel(nn.Module):
    """HF-format Pixtral vision encoder.

    Applies a patch convolution and RMSNorm, then a transformer with
    rotary position embeddings. All images of a call are packed into one
    sequence, with an attention mask built from the per-image patch counts.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers.")

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    # `dtype` and `device` are read from the parameters on every access
    # instead of being captured once in __init__, so they stay correct if
    # the module is moved or cast after construction.
    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: List[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size).to(
                self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        # One entry per image: the number of patches it contributes.
        patches_per_image = [
            p.shape[-2] * p.shape[-1] for p in patch_embeds_list
        ]
        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                patches_per_image)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                patches_per_image, patch_embeds)

        return_all_hidden_states = feature_sample_layers is not None
        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states)

        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                             self.config.num_hidden_layers)

        return out

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights, merging q/k/v and gate/up shards into fused params.

        Weights of transformer layers at or beyond the number of layers
        that were actually built are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Redirect the shard to the fused parameter it belongs to.
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a shard of a fused parameter: load it as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params