# SPDX-License-Identifier: Apache-2.0

import math
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import (
    _num_image_tokens as _get_pixtral_hf_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
    PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False

def get_max_pixtral_image_tokens(ctx: InputContext):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)
    mm_encoder = tokenizer.instruct.mm_encoder

    max_image_size = mm_encoder.mm_config.max_image_size
    image_patch_size = mm_encoder.mm_config.image_patch_size

    return ((max_image_size // image_patch_size)**2)
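# Illustrative sizing note (not part of the upstream code): with the values
# Pixtral typically ships (max_image_size=1024 and image_patch_size=16, both
# assumed here), the cap computed above is (1024 // 16) ** 2 = 64 * 64 = 4096
# image tokens for a single maximally-sized image.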

def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)

    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
    image_token_id = mm_encoder.special_ids.img

    mm_config = ctx.get_mm_config()
    num_images = mm_config.limit_per_prompt.get("image", 1)

    # dummy size
    size = 256
    image = Image.new("RGB", (size, size), color=0)

    encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image))
    image_feature_size = len(encoding.tokens)
    num_image_tokens = image_feature_size * num_images
    seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )

    mm_data = {"image": num_images * [image]}
    mm_placeholders = {
        "image":
        consecutive_placeholder_ranges(num_items=num_images,
                                       item_size=image_feature_size)
    }
    return DummyData(seq_data, mm_data, mm_placeholders)

def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalKwargs:
    """Maps the input data to its MultiModalKwargs (if any).

    Args:
        ctx: Context of the loaded model.
        data: data potentially containing PIL images to be processed
            and mapped to `images`.

    Returns:
        MultiModalKwargs containing the stacked normalized images tensor or
        image embeddings.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)

    data_list = data if isinstance(data, list) else [data]

    images = []
    image_tokens_list = []
    for image_data in data_list:
        image = ImageChunk(image=image_data)
        encoding = tokenizer.instruct.mm_encoder(image)
        image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
        images.append(image)
        image_tokens_list.append(encoding.tokens)

    image_tokens = torch.tensor([
        token_id for image_tokens in image_tokens_list
        for token_id in image_tokens
    ])
    return MultiModalKwargs({"images": images, "image_tokens": image_tokens})
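# Rough output sketch for the mapper above (illustrative only, assuming one
# RGB input that mistral_common resizes to 512 x 512 with 16-pixel patches):
# "images" holds one float16 tensor of shape (3, 512, 512) per image, and
# "image_tokens" is a single flat 1D tensor with the [IMG] / [IMG_BREAK] /
# [IMG_END] ids produced for every image, concatenated in prompt order.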

def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    prompt_token_ids = inputs.get("prompt_token_ids")
    prompt = inputs.get("prompt")
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)

    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
    image_token_id = mm_encoder.special_ids.img
    image_break_id = mm_encoder.special_ids.img_break
    image_end_id = mm_encoder.special_ids.img_end

    if image_token_id not in inputs['prompt_token_ids']:
        raise ValueError(
            f"You've passed {inputs=} without {image_token_id=}."
            " Make sure to process your input via mistral_common's"
            " tokenizer or pass a chat completion request."
            " For more info, see: "
            "https://github.com/vllm-project/vllm/issues/8411.")

    # Get precise tracking of placeholder positions
    placeholder_ranges = []
    curr_offset = -1
    curr_length = 0
    for i in range(len(prompt_token_ids)):
        if prompt_token_ids[i] in (image_token_id, image_break_id):
            if curr_offset < 0:
                curr_offset = i
            curr_length += 1
        elif prompt_token_ids[i] == image_end_id:
            curr_length += 1
            placeholder_ranges.append(
                PlaceholderRange(offset=curr_offset, length=curr_length))
            curr_offset = -1
            curr_length = 0
        else:
            pass
    return token_inputs(prompt=prompt,
                        prompt_token_ids=prompt_token_ids,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": placeholder_ranges})
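# Placeholder-tracking example for the loop above (illustrative, with made-up
# ids IMG=10, IMG_BREAK=12, IMG_END=13): for prompt_token_ids
# [1, 10, 10, 12, 10, 10, 13, 5] the loop emits a single
# PlaceholderRange(offset=1, length=6) spanning every image token up to and
# including [IMG_END]; a second image later in the prompt starts a new range
# at its first [IMG] token.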

@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        if not ("image_break_token_id" in vision_args
                and "image_end_token_id" in vision_args):
            raise ValueError(
                "'image_break_token_id' and 'image_end_token_id' not found "
                "in the vision_encoder arguments. Please download the latest "
                "version of 'params.json' from the model repository.")

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input, image_tokens = self._parse_and_validate_image_input(
            **kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # NOTE: We patch the outputs of the vision encoder with embeddings
        # from `[IMG_BREAK]` and `[IMG_END]` tokens.
        image_embeds = self.language_model.get_input_embeddings(image_tokens)
        image_token_mask = image_tokens == self.vision_args.image_token_id
        image_embeds[image_token_mask] = vision_embeddings

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the indices of `[IMG_END]` token.
        image_end_mask = image_tokens == self.vision_args.image_end_token_id
        split_indices = torch.where(image_end_mask)[0] + 1
        if len(split_indices) <= 1:
            # Do not split, return as tensor of shape [1, fs, hs]
            return image_embeds.unsqueeze(0)

        # If the last split index is the last index in image_tokens, we
        # ignore it to avoid empty split tensor
        if split_indices[-1] == len(image_tokens):
            split_indices = split_indices[:-1]

        image_embeds = image_embeds.tensor_split(split_indices.cpu())
        return image_embeds

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings, [
                    self.vision_args.image_token_id,
                    self.vision_args.image_break_token_id,
                    self.vision_args.image_end_token_id,
                ])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None,
        image_tokens: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
        if images is None:
            return None, None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, W, H = images.shape
            images = images.reshape(N * B, C, W, H)
            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flatten_images = []
            for imgs_per_req in images:
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

                flatten_images.extend(imgs_per_req)

            images = flatten_images

        if isinstance(image_tokens, torch.Tensor):
            # image_tokens are batched
            image_tokens = image_tokens.flatten()
        elif isinstance(image_tokens, list):
            # image_tokens are of different lengths thus passed as a list
            image_tokens = torch.cat(image_tokens)

        assert image_tokens.dim() == 1

        return images, image_tokens

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        # Get references to parameters for direct loading
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())

        def llm_weights_generator():
            # Single pass over weights
            for name, w in weights:
                if is_vision_encoder_weights((name, w)):
                    # Load vision encoder weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_encoder_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_vision_lang_adapter_weights((name, w)):
                    # Load vision-language adapter weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_lang_adapter_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())
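# Weight-routing sketch for load_weights above (illustrative, with made-up
# checkpoint keys): "vision_encoder.transformer.layers.0.wq.weight" and
# "vision_language_adapter.w_in.weight" are loaded in place after stripping
# the leading component, while every other key is yielded lazily to
# language_model.load_weights via llm_weights_generator().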

# Vision encoder
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int
    image_break_token_id: int
    image_end_token_id: int
    adapter_bias: bool = True

def _reshape_for_broadcast(freqs_cis: torch.Tensor,
                           x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [
        d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
    ]
    return freqs_cis.view(*shape)

def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
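# Shape sketch for precompute_freqs_cis_2d (illustrative, assuming dim=64 and
# height=width=64, i.e. a 1024px image with 16px patches): freqs has 32
# entries, freqs_h/freqs_w take the even/odd halves (16 each), freqs_2d ends
# up with shape (64, 64, 32), and the returned complex tensor keeps that
# shape so it can be indexed by each patch's (row, col) position.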

def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class FeedForward(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(args.hidden_size,
                            args.intermediate_size,
                            bias=False)
        self.w2 = nn.Linear(args.intermediate_size,
                            args.hidden_size,
                            bias=False)
        self.w3 = nn.Linear(args.hidden_size,
                            args.intermediate_size,
                            bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class Attention(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)

class TransformerBlock(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(self.attention_norm(x),
                                   mask=mask,
                                   freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out

class Transformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x

def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor:
    positions = torch.cat([
        torch.stack(
            torch.meshgrid(
                torch.arange(p.shape[-2]),
                torch.arange(p.shape[-1]),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2) for p in patch_embeds_list
    ])
    return positions
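# Meshgrid example (illustrative): for a single patch-embedding map with a
# 2 x 3 spatial grid, position_meshgrid returns the six (row, col) pairs
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]], which VisionTransformer
# uses to gather per-patch entries from the precomputed 2D rotary table.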

class VisionTransformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        else:
            raise ImportError("Xformers is required for Pixtral inference "
                              "with the Mistral format")
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
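# Attention-mask note for VisionTransformer.forward (illustrative): with two
# images of 64 and 16 patches, BlockDiagonalMask.from_seqlens([64, 16]) lets
# the single concatenated sequence of 80 patch embeddings attend only within
# its own image, so several images share one attention call without any
# cross-image attention.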

class VisionLanguageAdapter(nn.Module):

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=args.adapter_bias,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))

#### HF Transformers version of Pixtral ####
# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are placed
# instead of the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.

def get_pixtral_hf_patch_grid_length(*, image_size: int,
                                     patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
    # assert image_size % patch_size == 0
    return image_size // patch_size

def get_pixtral_hf_image_feature_size(
    *,
    image_size: int,
    patch_size: int,
) -> int:
    grid_length = get_pixtral_hf_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )

    # Consider the image_break_token
    return (grid_length + 1) * grid_length
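# Worked example (illustrative, assuming image_size=1024 and patch_size=16):
# grid_length = 1024 // 16 = 64, and the "+ 1" accounts for the break token
# emitted after each row of patches, so the feature size for a full-size
# square image is (64 + 1) * 64 = 4160 tokens.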

def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
    grid_length = get_pixtral_hf_patch_grid_length(
        image_size=hf_config.image_size,
        patch_size=hf_config.patch_size,
    )

    # Consider the image_break_token
    return (grid_length + 1) * grid_length

def dummy_image_for_pixtral_hf(
    hf_config: PixtralVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}

# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
def get_pixtral_hf_image_feature_grid_size(
    hf_config: PixtralVisionConfig,
    *,
    image_width: int,
    image_height: int,
) -> tuple[int, int]:
    max_width = max_height = hf_config.image_size
    patch_width = patch_height = hf_config.patch_size

    ratio = max(image_width / max_width, image_height / max_height)

    if ratio > 1:
        image_width = int(math.ceil(image_width / ratio))
        image_height = int(math.ceil(image_height / ratio))

    nrows, ncols = _get_pixtral_hf_num_image_tokens(
        (image_height, image_width),
        (patch_height, patch_width),
    )  # type: ignore

    return ncols, nrows
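# Worked example (illustrative, assuming image_size=1024 and patch_size=16):
# a 2048 x 1536 input gives ratio = max(2048 / 1024, 1536 / 1024) = 2.0, so
# it is resized to 1024 x 768 before patching; _num_image_tokens then yields
# (nrows, ncols) = (48, 64), which this helper returns flipped as
# (ncols, nrows) = (64, 48).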

class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        return get_pixtral_hf_image_feature_size(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

    def get_max_image_tokens(self) -> int:
        return get_max_pixtral_hf_image_tokens(self.vision_config)

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        return get_pixtral_hf_patch_grid_length(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

class PixtralHFMLP(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=config.hidden_size,
            output_sizes=[config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
                                           output_size=config.hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_and_mul(gate_up)
        x, _ = self.down_proj(x)
        return x

class PixtralHFAttention(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to apply HF's Rotary Position Embedding
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back for attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        out = out.view(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None

class PixtralHFTransformerBlock(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        r, _ = self.attention.forward(self.attention_norm(hidden_states),
                                      attention_mask=attention_mask,
                                      position_embeddings=position_embeddings)
        h = hidden_states + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out

class PixtralHFTransformer(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            PixtralHFTransformerBlock(config=config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor:
        hidden_states_pool = []

        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x

class PixtralHFVisionModel(nn.Module):

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers.")

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    def forward(
        self,
        pixel_values: List[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size).to(
                self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
                patch_embeds)

        return_all_hidden_states = feature_sample_layers is not None
        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states)

        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                             self.config.num_hidden_layers)

        return out

    # (TODO) Add prefix argument for filtering out weights to be loaded
    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params