[VLM] Move supported limits and max tokens to merged multi-modal processor (#11669)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -9,7 +9,6 @@ from transformers import BatchFeature, PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@@ -87,8 +86,8 @@ class AriaVisionModel(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
pixel_mask: Optional[torch.BoolTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
|
||||
pixel_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
||||
|
||||
vit_oup = self.vision_model(
|
||||
@@ -100,7 +99,8 @@ class AriaVisionModel(nn.Module):
|
||||
|
||||
return vit_oup, image_atts
|
||||
|
||||
def _create_patch_attention_mask(self, pixel_mask):
|
||||
def _create_patch_attention_mask(
|
||||
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if pixel_mask is None:
|
||||
return None
|
||||
|
||||
@@ -115,7 +115,8 @@ class AriaVisionModel(nn.Module):
|
||||
)
|
||||
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
def _create_image_attention_mask(self, patch_attention_mask):
|
||||
def _create_image_attention_mask(
|
||||
self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
if patch_attention_mask is None:
|
||||
return None
|
||||
|
||||
@@ -125,13 +126,13 @@ class AriaVisionModel(nn.Module):
|
||||
|
||||
class FFN(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, ff_dim, output_dim):
|
||||
def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
|
||||
self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
|
||||
self.act = get_act_fn("gelu_new")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.linear_in(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states, _ = self.linear_out(hidden_states)
|
||||
@@ -140,7 +141,7 @@ class FFN(nn.Module):
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
|
||||
def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
|
||||
def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
@@ -149,12 +150,16 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
self.linear = nn.Linear(embed_dim, embed_dim)
|
||||
self.dropout = nn.Dropout(drop_out_rate)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
self.ln_kv = nn.LayerNorm(kv_dim)
|
||||
|
||||
def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
|
||||
|
||||
@@ -169,11 +174,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
attn_output = attn_output.permute(1, 0, 2)
|
||||
|
||||
if add_residual:
|
||||
attn_output = hidden_states + self.dropout(
|
||||
self.linear(attn_output))
|
||||
else:
|
||||
attn_output = self.dropout(self.linear(attn_output))
|
||||
attn_output = self.linear(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_to_query_dict,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kv_dim,
|
||||
ff_dim,
|
||||
output_dim,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
patch_to_query_dict: dict[int, int],
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
kv_dim: int,
|
||||
ff_dim: int,
|
||||
output_dim: int,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_to_query_dict = patch_to_query_dict
|
||||
self.embed_dim = embed_dim
|
||||
@@ -224,7 +225,11 @@ class AriaProjector(nn.Module):
|
||||
self.ln_ffn = norm_layer(embed_dim)
|
||||
self.ffn = FFN(embed_dim, ff_dim, output_dim)
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
bs = x.shape[0]
|
||||
queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
|
||||
|
||||
@@ -442,13 +447,18 @@ def build_mm_projector(config: PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
def get_max_aria_image_tokens(ctx: InputContext):
    """Return the largest value in the projector's patch-to-query mapping.

    The mapping is read from ``projector_patch_to_query_dict`` on the HF
    config returned by ``ctx.get_hf_config()``.
    """
    patch_to_query = ctx.get_hf_config().projector_patch_to_query_dict
    query_counts = patch_to_query.values()
    return max(query_counts)
|
||||
|
||||
|
||||
class AriaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"image"`` is listed and its limit is ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    limits: Mapping[str, Optional[int]] = {"image": None}
    return limits
|
||||
|
||||
def _get_num_image_tokens(self) -> int:
|
||||
hf_config = self.ctx.get_hf_config()
|
||||
return max(hf_config.projector_patch_to_query_dict.values())
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The single ``"image"`` entry is whatever ``_get_num_image_tokens``
    reports.
    """
    image_tokens = self._get_num_image_tokens()
    return dict(image=image_tokens)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
@@ -468,13 +478,13 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
hf_config = self.ctx.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
max_image_tokens = get_max_aria_image_tokens(self.ctx)
|
||||
num_image_tokens = self._get_num_image_tokens()
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=[image_token_id] * max_image_tokens,
|
||||
replacement=[image_token_id] * num_image_tokens,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -504,7 +514,6 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
|
||||
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
"""
|
||||
|
||||
@@ -9,7 +9,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2Processor,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@@ -18,7 +17,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
@@ -398,15 +396,17 @@ class Blip2QFormerModel(nn.Module):
|
||||
return sequence_output
|
||||
|
||||
|
||||
def get_max_blip2_image_tokens(ctx: InputContext):
    """Return ``num_query_tokens`` from the BLIP-2 HF config.

    The config is fetched with ``ctx.get_hf_config(Blip2Config)``.
    """
    blip2_config = ctx.get_hf_config(Blip2Config)
    num_query_tokens = blip2_config.num_query_tokens
    return num_query_tokens
|
||||
|
||||
|
||||
class Blip2MultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
    """Build the data parser for this processor.

    The parser is constructed with ``max_mm_counts={"image": 1}``.
    """
    image_cap = {"image": 1}
    return MultiModalDataParser(max_mm_counts=image_cap)
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"image"`` is listed, with a limit of ``1``.
    """
    limits: Mapping[str, Optional[int]] = dict(image=1)
    return limits
|
||||
|
||||
def _get_num_image_tokens(self) -> int:
    """Return ``num_query_tokens`` from the BLIP-2 HF config.

    The config is fetched with ``self.ctx.get_hf_config(Blip2Config)``.
    """
    blip2_config = self.ctx.get_hf_config(Blip2Config)
    num_query_tokens = blip2_config.num_query_tokens
    return num_query_tokens
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The ``"image"`` entry is the value of ``_get_num_image_tokens()``.
    """
    per_item: dict[str, int] = {}
    per_item["image"] = self._get_num_image_tokens()
    return per_item
|
||||
|
||||
def _get_hf_processor(self) -> Blip2Processor:
|
||||
return self.ctx.get_hf_processor(Blip2Processor)
|
||||
@@ -427,7 +427,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
max_image_tokens = get_max_blip2_image_tokens(self.ctx)
|
||||
max_image_tokens = self._get_num_image_tokens()
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@@ -480,7 +480,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@@ -31,7 +30,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
@@ -43,11 +41,6 @@ from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# These configs are not part of the model config but the preprocessor
|
||||
# and processor files, so we hardcode them in the model file for now.
|
||||
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
|
||||
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
|
||||
|
||||
|
||||
class ChameleonImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
@@ -55,14 +48,17 @@ class ChameleonImagePixelInputs(TypedDict):
|
||||
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
||||
|
||||
|
||||
def get_max_chameleon_image_tokens(ctx: InputContext):
    """Return the module constant ``CHAMELEON_IMAGE_SEQ_LENGTH``.

    ``ctx`` is accepted but not consulted.
    """
    max_image_tokens = CHAMELEON_IMAGE_SEQ_LENGTH
    return max_image_tokens
|
||||
|
||||
|
||||
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
    """Build the data parser for this processor.

    The parser is constructed with ``max_mm_counts={"image": 1}``.
    """
    max_counts = dict(image=1)
    return MultiModalDataParser(max_mm_counts=max_counts)
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"image"`` is listed, with a limit of ``1``.
    """
    image_limit: Optional[int] = 1
    return {"image": image_limit}
|
||||
|
||||
def _get_num_image_tokens(self) -> int:
|
||||
processor = self._get_hf_processor()
|
||||
return processor.image_seq_length
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The ``"image"`` entry is the value of ``_get_num_image_tokens()``.
    """
    num_image_tokens = self._get_num_image_tokens()
    max_tokens: Mapping[str, int] = {"image": num_image_tokens}
    return max_tokens
|
||||
|
||||
def _get_hf_processor(self) -> ChameleonProcessor:
|
||||
return self.ctx.get_hf_processor(ChameleonProcessor)
|
||||
@@ -88,7 +84,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
|
||||
target="<image>",
|
||||
replacement="".join([
|
||||
processor.image_start_token,
|
||||
processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
|
||||
processor.image_token * self._get_num_image_tokens(),
|
||||
processor.image_end_token,
|
||||
]),
|
||||
)
|
||||
@@ -98,12 +94,15 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
config = self.ctx.get_hf_config(ChameleonConfig)
|
||||
|
||||
width = height = config.vq_config.resolution
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
|
||||
height=CHAMELEON_CROP_SIZE_HEIGHT,
|
||||
self._get_dummy_images(width=width,
|
||||
height=height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
@@ -902,7 +901,6 @@ class ChameleonModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
|
||||
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
@@ -931,9 +929,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
|
||||
CHAMELEON_CROP_SIZE_WIDTH)
|
||||
vq_config: ChameleonVQVAEConfig = self.config.vq_config
|
||||
expected_dims = (3, vq_config.resolution, vq_config.resolution)
|
||||
actual_dims = tuple(data.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
|
||||
@@ -25,7 +25,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
@@ -34,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
|
||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
@@ -48,9 +47,6 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
_IMAGE_TOKEN_ID = 71011
|
||||
_NEWLINE_TOKEN_ID = 71019
|
||||
|
||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
|
||||
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
|
||||
|
||||
|
||||
class FuyuImagePatchInputs(TypedDict):
|
||||
type: Literal["image_patches"]
|
||||
@@ -67,43 +63,49 @@ class FuyuImagePatchInputs(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
def _get_fuyu_num_image_tokens(
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Calculate the number of image tokens needed for a given image size.
|
||||
|
||||
The expected Fuyu image prompts can be expressed as:
|
||||
|
||||
.. code-block::
|
||||
(image_token * ncols + newline_token) * nrows
|
||||
|
||||
Args:
|
||||
image_size: Tuple[int, int] - `(width, height)` of the image
|
||||
|
||||
Returns:
|
||||
ncols: int - number of image tokens in `x` direction
|
||||
nrows: int - number of image tokens in `y` direction
|
||||
"""
|
||||
ncols = math.ceil(image_width / 30)
|
||||
nrows = math.ceil(image_height / 30)
|
||||
return ncols, nrows
|
||||
|
||||
|
||||
def get_max_fuyu_image_tokens(ctx: InputContext):
    """Return the token count for the maximum Fuyu image size.

    The grid is computed by ``_get_fuyu_num_image_tokens`` for
    ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` x ``MAX_IMAGE_FEATURE_SIZE_HEIGHT``.
    ``ctx`` is accepted but not consulted.
    """
    grid = _get_fuyu_num_image_tokens(
        image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )
    cols_per_row, num_rows = grid

    # Each row contributes its column tokens plus one extra token.
    tokens_per_row = cols_per_row + 1
    return tokens_per_row * num_rows
|
||||
|
||||
|
||||
class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
    """Build the data parser for this processor.

    The parser is constructed with ``max_mm_counts={"image": 1}``.
    """
    parser = MultiModalDataParser(max_mm_counts={"image": 1})
    return parser
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"image"`` is listed, with a limit of ``1``.
    """
    supported: dict[str, Optional[int]] = {}
    supported["image"] = 1
    return supported
|
||||
|
||||
def _get_image_target_size(self) -> ImageSize:
    """Return the HF image processor's configured size as an ``ImageSize``.

    Reads the ``"width"`` and ``"height"`` entries of
    ``image_processor.size`` on the processor returned by
    ``self._get_hf_processor()``.
    """
    image_processor: FuyuImageProcessor = (
        self._get_hf_processor().image_processor)
    size_config = image_processor.size

    width = size_config["width"]
    height = size_config["height"]
    return ImageSize(width=width, height=height)
|
||||
|
||||
def _get_image_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> tuple[int, int]:
|
||||
target_width, target_height = self._get_image_target_size()
|
||||
|
||||
if not (image_width <= target_width and image_height <= target_height):
|
||||
height_scale_factor = target_height / image_height
|
||||
width_scale_factor = target_width / image_width
|
||||
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||
|
||||
image_height = int(image_height * optimal_scale_factor)
|
||||
image_width = int(image_width * optimal_scale_factor)
|
||||
|
||||
ncols = math.ceil(image_width / 30)
|
||||
nrows = math.ceil(image_height / 30)
|
||||
return ncols, nrows
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The ``"image"`` entry is computed from the grid that
    ``_get_image_grid_size`` reports for an image of exactly the target
    size: ``(ncols + 1) * nrows``.
    """
    width, height = self._get_image_target_size()
    grid = self._get_image_grid_size(image_width=width, image_height=height)
    ncols, nrows = grid

    # Each row contributes its column tokens plus one extra token.
    return {"image": nrows * (ncols + 1)}
|
||||
|
||||
def _get_hf_processor(self) -> FuyuProcessor:
|
||||
return self.ctx.get_hf_processor(FuyuProcessor)
|
||||
@@ -166,28 +168,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
eot_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(eot_token_id, int)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor: FuyuImageProcessor = hf_processor.image_processor
|
||||
target_size = image_processor.size
|
||||
target_height, target_width = (target_size["height"],
|
||||
target_size["width"])
|
||||
|
||||
def get_replacement_fuyu(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
width, height = image_size.width, image_size.height
|
||||
if not (width <= target_width and height <= target_height):
|
||||
height_scale_factor = target_height / height
|
||||
width_scale_factor = target_width / width
|
||||
optimal_scale_factor = min(height_scale_factor,
|
||||
width_scale_factor)
|
||||
|
||||
height = int(height * optimal_scale_factor)
|
||||
width = int(width * optimal_scale_factor)
|
||||
|
||||
ncols, nrows = _get_fuyu_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
ncols, nrows = self._get_image_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
|
||||
@@ -225,12 +212,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
target_width, target_height = self._get_image_target_size()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
@@ -240,7 +228,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
|
||||
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
|
||||
@@ -119,6 +119,12 @@ def get_max_llava_image_tokens(ctx: InputContext):
|
||||
|
||||
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"image"`` is listed and its limit is ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    return dict(image=None)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The ``"image"`` entry is ``get_max_llava_image_tokens(self.ctx)``.
    """
    max_image_tokens = get_max_llava_image_tokens(self.ctx)
    return {"image": max_image_tokens}
|
||||
|
||||
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
    """Fetch the HF processor from the input context.

    ``(LlavaProcessor, PixtralProcessor)`` is passed to
    ``self.ctx.get_hf_processor`` as its single positional argument
    (presumably the accepted processor types -- confirm against
    ``InputContext``).
    """
    hf_processor = self.ctx.get_hf_processor(
        (LlavaProcessor, PixtralProcessor))
    return hf_processor
|
||||
|
||||
@@ -324,7 +330,6 @@ def init_vision_tower_for_llava(
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
@@ -649,7 +654,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
# To use this model, please use
|
||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
|
||||
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
pass
|
||||
|
||||
@@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@@ -306,25 +305,32 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
return image_features_hd_newline
|
||||
|
||||
|
||||
def get_max_phi3v_image_tokens(
    ctx: InputContext,
    *,
    num_crops: Optional[int] = None,
) -> int:
    """Return the image-token count for the maximum image feature size.

    Args:
        ctx: Input context whose ``get_hf_processor`` supplies the HF
            processor.
        num_crops: Optional override forwarded to ``ctx.get_hf_processor``.
            It is forwarded whenever it is not ``None``.

    Returns:
        The result of the processor's
        ``calc_num_image_tokens_from_image_size`` for
        ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` x ``MAX_IMAGE_FEATURE_SIZE_HEIGHT``.
    """
    hf_processor_mm_kwargs = {}
    # Compare against None instead of relying on truthiness, so that an
    # explicit ``num_crops=0`` is forwarded rather than silently dropped.
    if num_crops is not None:
        hf_processor_mm_kwargs["num_crops"] = num_crops

    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)

    return processor.calc_num_image_tokens_from_image_size(
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
|
||||
|
||||
|
||||
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"image"`` is listed and its limit is ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    image_limit: Optional[int] = None
    return {"image": image_limit}
|
||||
|
||||
def _get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
processor = self._get_hf_processor()
|
||||
|
||||
return processor.calc_num_image_tokens_from_image_size( # type: ignore
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The ``"image"`` entry is ``_get_num_image_tokens`` evaluated at
    ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` x ``MAX_IMAGE_FEATURE_SIZE_HEIGHT``.
    """
    largest_image = dict(
        image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
    return {"image": self._get_num_image_tokens(**largest_image)}
|
||||
|
||||
def _get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
@@ -332,6 +338,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
) -> ProcessorMixin:
|
||||
if num_crops is not None:
|
||||
return self.ctx.get_hf_processor(num_crops=num_crops)
|
||||
|
||||
return self.ctx.get_hf_processor()
|
||||
|
||||
def _call_hf_processor(
|
||||
@@ -375,7 +382,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
|
||||
image_processor = hf_processor.image_processor # type: ignore
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
bos_token_id = tokenizer.bos_token_id
|
||||
@@ -385,9 +391,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
|
||||
width=image_size.width,
|
||||
height=image_size.height,
|
||||
num_tokens = self._get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
|
||||
@@ -467,7 +473,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
return result
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
|
||||
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
|
||||
@@ -33,13 +33,12 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
@@ -80,15 +79,18 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
||||
return feat_lengths, output_lengths
|
||||
|
||||
|
||||
def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
    """Return the audio token bound derived from the HF audio config.

    The value is ``(max_source_positions - 2) // 2 + 1``, where
    ``max_source_positions`` comes from ``audio_config`` on the
    ``Qwen2AudioConfig`` fetched through ``ctx``.
    """
    audio_config = ctx.get_hf_config(Qwen2AudioConfig).audio_config
    return (audio_config.max_source_positions - 2) // 2 + 1
|
||||
|
||||
|
||||
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits advertised by this processor.

    Only ``"audio"`` is listed and its limit is ``None``
    (presumably meaning "no fixed cap" -- confirm against the base class).
    """
    limits: Mapping[str, Optional[int]] = {"audio": None}
    return limits
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
    """Return the maximum token count per item, keyed by modality.

    The ``"audio"`` entry is ``(max_source_positions - 2) // 2 + 1``,
    where ``max_source_positions`` comes from ``audio_config`` on the
    ``Qwen2AudioConfig`` fetched through ``self.ctx``.
    """
    audio_config = self.ctx.get_hf_config(Qwen2AudioConfig).audio_config
    source_positions = audio_config.max_source_positions
    return dict(audio=(source_positions - 2) // 2 + 1)
|
||||
|
||||
def _get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
@@ -157,11 +159,21 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
audio_output_lengths = []
|
||||
else:
|
||||
assert isinstance(feature_attention_mask, torch.Tensor)
|
||||
_, audio_output_lengths = _get_feat_extract_output_lengths(
|
||||
_, audio_output_lens = _get_feat_extract_output_lengths(
|
||||
feature_attention_mask.sum(-1))
|
||||
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
|
||||
def get_replacement_qwen2_audio(item_idx: int):
|
||||
return [placeholder] * audio_output_lengths[item_idx]
|
||||
num_placeholders = audio_output_lengths[item_idx]
|
||||
if num_placeholders == 0:
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
audio = audios.get(item_idx)
|
||||
raise ValueError(
|
||||
f"The audio {audio} (len={len(audio)}) is too short "
|
||||
"to be represented inside the model")
|
||||
|
||||
return [placeholder] * num_placeholders
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@@ -171,6 +183,14 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
]
|
||||
|
||||
def _always_apply_prompt_replacements(self) -> bool:
|
||||
# HF never applies prompt replacements, so we have to do it ourselves
|
||||
# _find_placeholders may incorrectly think that HF has already performed
|
||||
# processing for multi-audio input when the input audios are short
|
||||
# (the corresponding placeholders may take up fewer tokens than
|
||||
# the number of audio items)
|
||||
return True
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
@@ -192,8 +212,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", get_max_qwen2_audio_audio_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
|
||||
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
@@ -40,7 +40,6 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
@@ -650,8 +649,9 @@ def _get_vision_info(
|
||||
width: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
*,
|
||||
do_resize: bool = True,
|
||||
data_type_key: str = "image",
|
||||
modality: str = "image",
|
||||
mm_count: int = 1,
|
||||
):
|
||||
"""Get information (resized height / width and number of vision tokens)
|
||||
@@ -671,11 +671,12 @@ def _get_vision_info(
|
||||
else:
|
||||
resized_height, resized_width = height, width
|
||||
|
||||
if data_type_key == "image":
|
||||
if modality == "image":
|
||||
grid_t = mm_count
|
||||
else:
|
||||
assert data_type_key == "video"
|
||||
elif modality == "video":
|
||||
grid_t = max(mm_count // temporal_patch_size, 1)
|
||||
else:
|
||||
raise ValueError(f"Modality {modality} is not supported")
|
||||
|
||||
grid_h = resized_height // patch_size
|
||||
grid_w = resized_width // patch_size
|
||||
@@ -691,41 +692,11 @@ def _get_image_processor(hf_processor: Qwen2VLProcessor):
|
||||
return image_processor
|
||||
|
||||
|
||||
def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
                               data_type_key: str,
                               *,
                               min_pixels: Optional[int] = None,
                               max_pixels: Optional[int] = None) -> int:
    """Return the third value of ``_get_vision_info`` for a 9999999-square input.

    ``min_pixels`` / ``max_pixels`` are used when truthy; otherwise the
    image processor's own ``min_pixels`` / ``max_pixels`` apply.
    ``data_type_key`` is forwarded unchanged to ``_get_vision_info``.
    """
    vision_config = ctx.get_hf_config(Qwen2VLConfig).vision_config
    image_processor = _get_image_processor(
        ctx.get_hf_processor(Qwen2VLProcessor))

    # Falsy overrides fall back to the image processor's configured bounds.
    pixel_floor = min_pixels or image_processor.min_pixels
    pixel_ceiling = max_pixels or image_processor.max_pixels

    _height, _width, max_tokens = _get_vision_info(
        vision_config,
        height=9999999,
        width=9999999,
        min_pixels=pixel_floor,
        max_pixels=pixel_ceiling,
        data_type_key=data_type_key,
    )
    return max_tokens
|
||||
|
||||
|
||||
get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
|
||||
data_type_key="image")
|
||||
get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
|
||||
data_type_key="video")
|
||||
|
||||
|
||||
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
|
||||
dict[str, torch.Tensor]]):
|
||||
|
||||
def __init__(self, data: dict, modality: str) -> None:
|
||||
super().__init__(data)
|
||||
|
||||
self.modality = modality
|
||||
super().__init__(data, modality)
|
||||
|
||||
grid_thw = data[f"{modality}_grid_thw"]
|
||||
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
|
||||
@@ -734,9 +705,6 @@ class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
|
||||
for i in range(len(grid_thw))
|
||||
]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"{type(self).__name__}(modality={self.modality!r})")
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data[f"{self.modality}_grid_thw"])
|
||||
|
||||
@@ -792,6 +760,32 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
|
||||
|
||||
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None}
|
||||
|
||||
def _get_max_mm_tokens(self, modality: str) -> int:
|
||||
hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
|
||||
_, _, max_llm_image_tokens = _get_vision_info(
|
||||
vision_config,
|
||||
height=9999999,
|
||||
width=9999999,
|
||||
min_pixels=image_processor.min_pixels,
|
||||
max_pixels=image_processor.max_pixels,
|
||||
modality=modality,
|
||||
)
|
||||
return max_llm_image_tokens
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {
|
||||
"image": self._get_max_mm_tokens("image"),
|
||||
"video": self._get_max_mm_tokens("video"),
|
||||
}
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return Qwen2MultiModalDataParser()
|
||||
|
||||
@@ -908,9 +902,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"video", get_max_qwen2_vl_video_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
|
||||
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsLoRA, SupportsPP):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""PyTorch Ultravox model."""
|
||||
|
||||
import math
|
||||
from functools import cached_property, lru_cache
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
@@ -17,7 +17,6 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
@@ -58,23 +57,18 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
|
||||
UltravoxAudioEmbeddingInputs]
|
||||
|
||||
|
||||
@lru_cache
|
||||
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
|
||||
return WhisperFeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
|
||||
def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
|
||||
hf_config = ctx.get_hf_config(UltravoxConfig)
|
||||
return cached_feature_extractor(hf_config.audio_model_id)
|
||||
|
||||
|
||||
def get_ultravox_max_audio_tokens(ctx: InputContext):
|
||||
feature_extractor = whisper_feature_extractor(ctx)
|
||||
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
|
||||
|
||||
|
||||
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
|
||||
_AUDIO_TOKENS_PER_SECOND)
|
||||
|
||||
return {"audio": max_audio_tokens}
|
||||
|
||||
def _get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
@@ -322,8 +316,6 @@ class ModifiedWhisperEncoder(WhisperEncoder):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", get_ultravox_max_audio_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
|
||||
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user