Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
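The change is mechanical: yapf's continuation lines aligned to the opening bracket become ruff's black-compatible layout (4-space hanging indent, trailing comma, closing bracket on its own line), and isort's import grouping is presumably covered by ruff's import-sorting rules instead. As a rough illustration, one import from this file before and after (indentation here is illustrative only, since the flattened diff below does not preserve leading whitespace):

    # yapf + isort style (before): continuations aligned to the opening parenthesis
    from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                   QKVParallelLinear,
                                                   RowParallelLinear)

    # ruff format style (after): 4-space hanging indent, trailing comma, closing paren alone
    from vllm.model_executor.layers.linear import (
        MergedColumnParallelLinear,
        QKVParallelLinear,
        RowParallelLinear,
    )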
@@ -10,17 +10,20 @@ from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
PixtralRotaryEmbedding,
apply_rotary_pos_emb,
position_ids_in_meshgrid,
)
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
@@ -28,37 +31,50 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalUUIDDict, NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalUUIDDict,
NestedTensors,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs)
from .vision import (
VisionEncoderInfo,
VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs,
)

try:
from xformers import ops as xops
if (current_platform.is_cuda()
and current_platform.has_device_capability(100)):

if current_platform.is_cuda() and current_platform.has_device_capability(100):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
@@ -76,13 +92,16 @@ class PixtralImagePixelInputs(TensorSchema):
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image


The result of stacking `ImageEncoding.tokens` from each prompt.
"""

type: Literal["pixel_values"] = "pixel_values"

images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]
images: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
]


class PixtralProcessorAdapter:
@@ -150,7 +169,8 @@ class PixtralProcessorAdapter:
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
"https://github.com/vllm-project/vllm/issues/8411."
)

images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
@@ -163,16 +183,15 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)

return BatchFeature({
"input_ids":
torch.cat(images_tokens)[None].expand(len(text), -1),
"images":
images_processed,
})
return BatchFeature(
{
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
}
)


class PixtralProcessingInfo(BaseProcessingInfo):

def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
if not isinstance(tokenizer, MistralTokenizer):
@@ -209,7 +228,8 @@ class PixtralProcessingInfo(BaseProcessingInfo):
processor = self.get_hf_processor()

ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height)))
Image.new("RGB", (image_width, image_height))
)

return ncols * nrows

@@ -221,7 +241,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):


class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):

def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""

@@ -233,17 +252,17 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)

target_width, target_height = \
self.info.get_image_size_with_most_features()
target_width, target_height = self.info.get_image_size_with_most_features()

image_overrides = mm_options.get("image") if mm_options else None

return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides)
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}

def get_dummy_processor_inputs(
@@ -259,23 +278,27 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}

request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=dummy_text),
*(ImageChunk(image=image) for image in dummy_images),
]),
])
request = ChatCompletionRequest(
messages=[
UserMessage(
content=[
TextChunk(text=dummy_text),
*(ImageChunk(image=image) for image in dummy_images),
]
),
]
)
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens

return ProcessorInputs(prompt=dummy_tokens,
mm_data=dummy_mm_data,
tokenization_kwargs=tokenization_kwargs)
return ProcessorInputs(
prompt=dummy_tokens,
mm_data=dummy_mm_data,
tokenization_kwargs=tokenization_kwargs,
)


class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
):

class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
@@ -300,7 +323,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
image_size = images.get_image_size(item_idx)

ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_size.width, image_size.height)))
Image.new("RGB", (image_size.width, image_size.height))
)

tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
@@ -335,12 +359,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
return prompt_ids, mm_info, True


@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
info=PixtralProcessingInfo,
dummy_inputs=PixtralDummyInputsBuilder)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

@MULTIMODAL_REGISTRY.register_processor(
PixtralMultiModalProcessor,
info=PixtralProcessingInfo,
dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@@ -374,8 +398,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self.vision_encoder = VisionTransformer(self.vision_args)

if self.vision_args.add_pre_mm_projector_layer_norm:
self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
eps=1e-5)
self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5)

if self.vision_args.mm_projector_id == PATCH_MERGE:
self.patch_merger = PatchMerger(
@@ -385,13 +408,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
)

self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size)
self.vision_args, dim=config.text_config.hidden_size
)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.language_model.make_empty_intermediate_tensors
)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
self, **kwargs: object
) -> Optional[PixtralImagePixelInputs]:
images = kwargs.pop("images", None)
if images is None:
return None
@@ -407,23 +433,24 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> tuple[torch.Tensor, ...]:
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]
feature_sizes = [image_feature.shape[0] for image_feature in image_features]
image_features = torch.cat(image_features)
if self.vision_args.add_pre_mm_projector_layer_norm:
image_features = self.pre_mm_projector_norm(image_features)
if self.vision_args.mm_projector_id == PATCH_MERGE:
patch_size = self.vision_args.patch_size
spatial_merge_size_square = self.vision_args.spatial_merge_size**2
img_patch_dims = [(img.shape[1] // patch_size,
img.shape[2] // patch_size) for img in images]
img_patch_dims = [
(img.shape[1] // patch_size, img.shape[2] // patch_size)
for img in images
]
feature_sizes = [
feature_size // spatial_merge_size_square
for feature_size in feature_sizes
]
image_features = self.patch_merger(image_features,
image_sizes=img_patch_dims)
image_features = self.patch_merger(
image_features, image_sizes=img_patch_dims
)
image_embeds = self.vision_language_adapter(image_features)
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
@@ -431,8 +458,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -451,10 +477,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if intermediate_tensors is not None:
inputs_embeds = None

hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)

return hidden_states

@@ -465,7 +490,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder")

@@ -480,38 +504,42 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,

# Get references to parameters for direct loading
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
patch_merger_dict = dict(self.patch_merger.named_parameters(
)) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
pre_mm_projector_norm_dict = dict(
self.pre_mm_projector_norm.named_parameters(
)) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
vision_lang_adapter_dict = dict(
self.vision_language_adapter.named_parameters())
patch_merger_dict = (
dict(self.patch_merger.named_parameters())
if self.vision_args.mm_projector_id == PATCH_MERGE
else dict()
)
pre_mm_projector_norm_dict = (
dict(self.pre_mm_projector_norm.named_parameters())
if self.vision_args.add_pre_mm_projector_layer_norm
else dict()
)
vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters())

def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_encoder_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = '.'.join(name.split(".")[1:])
trimmed_name = ".".join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)):
# Load vision patch merger weights directly
trimmed_name = '.'.join(name.split(".")[1:])
trimmed_name = ".".join(name.split(".")[1:])
param = patch_merger_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)):
# Load vision pre_mm_projector_norm weights directly
trimmed_name = '.'.join(name.split(".")[1:])
trimmed_name = ".".join(name.split(".")[1:])
param = pre_mm_projector_norm_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = '.'.join(name.split(".")[1:])
trimmed_name = ".".join(name.split(".")[1:])
param = vision_lang_adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
@@ -542,8 +570,7 @@ class VisionEncoderArgs:
mm_projector_id: str = ""


def _reshape_for_broadcast(freqs_cis: torch.Tensor,
x: torch.Tensor) -> torch.Tensor:
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
@@ -554,9 +581,7 @@ def _reshape_for_broadcast(freqs_cis: torch.Tensor,
freqs_cis.shape,
(x.shape[1], x.shape[-1]),
)
shape = [
d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
]
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


@@ -571,7 +596,7 @@ def precompute_freqs_cis_2d(
to be indexed by (height, width) position tuples
"""
# (dim / 2) frequency bases
freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

h = torch.arange(height, device=freqs.device)
w = torch.arange(width, device=freqs.device)
@@ -603,26 +628,18 @@ def apply_rotary_emb_vit(


class FeedForward(nn.Module):

def __init__(self, args: VisionEncoderArgs):
super().__init__()
assert args.intermediate_size is not None
self.w1 = nn.Linear(args.hidden_size,
args.intermediate_size,
bias=False)
self.w2 = nn.Linear(args.intermediate_size,
args.hidden_size,
bias=False)
self.w3 = nn.Linear(args.hidden_size,
args.intermediate_size,
bias=False)
self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Attention(nn.Module):

def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
@@ -656,10 +673,7 @@ class Attention(nn.Module):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(q,
k,
v,
attn_mask=mask)
out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = out.transpose(1, 2)

out = out.reshape(batch, patches, self.n_heads * self.head_dim)
@@ -667,7 +681,6 @@ class Attention(nn.Module):


class TransformerBlock(nn.Module):

def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.attention = Attention(args)
@@ -681,9 +694,9 @@ class TransformerBlock(nn.Module):
mask: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(x),
mask=mask,
freqs_cis=freqs_cis)
r = self.attention.forward(
self.attention_norm(x), mask=mask, freqs_cis=freqs_cis
)
h = x + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
@@ -691,7 +704,6 @@ class TransformerBlock(nn.Module):


class Transformer(nn.Module):

def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.layers = torch.nn.ModuleList()
@@ -709,22 +721,26 @@ class Transformer(nn.Module):
return x


def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
positions = torch.cat([
torch.stack(
torch.meshgrid(
torch.arange(p.shape[-2]),
torch.arange(p.shape[-1]),
indexing="ij",
),
dim=-1,
).reshape(-1, 2) for p in patch_embeds_list
])
def position_meshgrid(
patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
positions = torch.cat(
[
torch.stack(
torch.meshgrid(
torch.arange(p.shape[-2]),
torch.arange(p.shape[-1]),
indexing="ij",
),
dim=-1,
).reshape(-1, 2)
for p in patch_embeds_list
]
)
return positions


class VisionTransformer(nn.Module):

def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
@@ -786,9 +802,7 @@ class VisionTransformer(nn.Module):
self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
]

patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
embed_sizes = [p.shape[1] for p in patch_embeds]

# flatten to a single sequence
@@ -802,13 +816,16 @@ class VisionTransformer(nn.Module):
# pass through Transformer with a block diagonal mask delimiting images
if USE_XFORMERS_OPS:
mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
)
else:
from transformers.models.pixtral.modeling_pixtral import (
generate_block_attention_mask)
generate_block_attention_mask,
)

mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
)
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

# squeeze dim 0 and split into separate tensors for each image
@@ -816,7 +833,6 @@ class VisionTransformer(nn.Module):


class VisionLanguageAdapter(nn.Module):

def __init__(self, args: VisionEncoderArgs, dim: int):
super().__init__()
assert isinstance(args, VisionEncoderArgs)
@@ -856,8 +872,9 @@ class PatchMerger(nn.Module):
bias=use_mlp_bias,
)

def forward(self, x: torch.Tensor,
image_sizes: list[tuple[int, int]]) -> torch.Tensor:
def forward(
self, x: torch.Tensor, image_sizes: list[tuple[int, int]]
) -> torch.Tensor:
# image_sizes specified in tokens
assert sum([h * w for h, w in image_sizes]) == len(x)

@@ -889,15 +906,14 @@ class PatchMerger(nn.Module):
"""

sub_grids = get_sub_grids(
x=x,
image_sizes=image_sizes,
spatial_merge_size=self.spatial_merge_size
x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
) # list of [d x sub_grid_size x sub_grid_size x n_patches]
permuted_tensor: list[torch.Tensor] = []
for grid in sub_grids:
n_patches = grid.shape[-1]
permuted_tensor.append(grid.view(-1, n_patches).t(
)) # n_patches x d * sub_grid_size * sub_grid_size
permuted_tensor.append(
grid.view(-1, n_patches).t()
) # n_patches x d * sub_grid_size * sub_grid_size
return torch.cat(
permuted_tensor, dim=0
) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
@@ -917,14 +933,15 @@ def get_sub_grids(
for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(
2, 0, 1)[None, :, :, :] # 1 x d x h x w
sub_grids = torch.nn.functional.unfold(image_grid,
kernel_size=sub_grid_size,
stride=sub_grid_size)
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[
None, :, :, :
] # 1 x d x h x w
sub_grids = torch.nn.functional.unfold(
image_grid, kernel_size=sub_grid_size, stride=sub_grid_size
)
sub_grids = sub_grids.view(
1, d, sub_grid_size, sub_grid_size,
-1) # 1 x d x sub_grid_size x sub_grid_size x n_patches
1, d, sub_grid_size, sub_grid_size, -1
) # 1 x d x sub_grid_size x sub_grid_size x n_patches

all_img_sub_grids.append(sub_grids[0])

@@ -940,7 +957,6 @@ def get_sub_grids(


class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):

def get_num_image_tokens(
self,
*,
@@ -993,7 +1009,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):


class PixtralHFMLP(nn.Module):

def __init__(
self,
config: PixtralVisionConfig,
@@ -1009,12 +1024,15 @@ class PixtralHFMLP(nn.Module):
output_sizes=[config.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1025,7 +1043,6 @@ class PixtralHFMLP(nn.Module):


class PixtralHFAttention(nn.Module):

def __init__(
self,
config: PixtralVisionConfig,
@@ -1081,14 +1098,12 @@ class PixtralHFAttention(nn.Module):
# Transpose q and k back for attention
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
out = xops.memory_efficient_attention(q,
k,
v,
attn_bias=attention_mask)
out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask)
else:
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask)
q, k, v, attn_mask=attention_mask
)
out = out.transpose(1, 2)

out = out.view(batch, patches, self.n_heads * self.head_dim)
@@ -1098,7 +1113,6 @@ class PixtralHFAttention(nn.Module):


class PixtralHFTransformerBlock(nn.Module):

def __init__(
self,
config: PixtralVisionConfig,
@@ -1109,12 +1123,12 @@ class PixtralHFTransformerBlock(nn.Module):
super().__init__()

self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.feed_forward = PixtralHFMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward")
self.attention = PixtralHFAttention(
config, quant_config=quant_config, prefix=f"{prefix}.attention"
)
self.feed_forward = PixtralHFMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

def forward(
@@ -1123,9 +1137,11 @@ class PixtralHFTransformerBlock(nn.Module):
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
r, _ = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings)
r, _ = self.attention.forward(
self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)
h = hidden_states + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
@@ -1133,7 +1149,6 @@ class PixtralHFTransformerBlock(nn.Module):


class PixtralHFTransformer(nn.Module):

def __init__(
self,
config: PixtralVisionConfig,
@@ -1149,12 +1164,16 @@ class PixtralHFTransformer(nn.Module):
else:
num_hidden_layers = num_hidden_layers_override

self.layers = nn.ModuleList([
PixtralHFTransformerBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])
self.layers = nn.ModuleList(
[
PixtralHFTransformerBlock(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
]
)

def forward(
self,
@@ -1177,7 +1196,6 @@ class PixtralHFTransformer(nn.Module):


class PixtralHFVisionModel(nn.Module):

def __init__(
self,
config: PixtralVisionConfig,
@@ -1211,7 +1229,8 @@ class PixtralHFVisionModel(nn.Module):
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.transformer.layers)} "
"layers.")
"layers."
)

if require_post_norm is True:
msg = "PixtralHFVisionModel does not have post-layernorm"
@@ -1219,8 +1238,7 @@ class PixtralHFVisionModel(nn.Module):

self.dtype = next(self.parameters()).dtype
self.device = next(self.parameters()).device
self.patch_positional_embedding = PixtralRotaryEmbedding(
config, self.device)
self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device)

def forward(
self,
@@ -1245,13 +1263,10 @@ class PixtralHFVisionModel(nn.Module):
"""
# pass images through initial convolution independently
patch_embeds_list = [
self.patch_conv(img.unsqueeze(0).to(self.dtype))
for img in pixel_values
self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values
]

patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
embed_sizes = [p.shape[1] for p in patch_embeds]

# flatten to a single sequence
@@ -1261,20 +1276,22 @@ class PixtralHFVisionModel(nn.Module):
# positional embeddings
position_ids = position_ids_in_meshgrid(
patch_embeds_list,
max_width=self.config.image_size // self.config.patch_size).to(
self.device)
position_embedding = self.patch_positional_embedding(
patch_embeds, position_ids)
max_width=self.config.image_size // self.config.patch_size,
).to(self.device)
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

if USE_XFORMERS_OPS:
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
)
else:
from transformers.models.pixtral.modeling_pixtral import (
generate_block_attention_mask)
generate_block_attention_mask,
)

attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
)

out = self.transformer(
patch_embeds,
@@ -1296,8 +1313,7 @@ class PixtralHFVisionModel(nn.Module):

# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -1317,7 +1333,7 @@ class PixtralHFVisionModel(nn.Module):
if layer_idx >= layer_count:
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -1327,8 +1343,7 @@ class PixtralHFVisionModel(nn.Module):
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params