Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
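
The excerpt below (one of the 1508 reformatted files) shows the mechanical pattern of this migration: yapf's paren-aligned hanging indents and backslash continuations are rewritten into the black-compatible layout emitted by `ruff format` (a 4-space hanging indent, one element per line, and a trailing comma once a statement no longer fits the line limit), and imports previously wrapped by isort are either collapsed onto one line or exploded with a trailing comma. A migration like this usually also replaces the yapf/isort tooling entries with `ruff format` plus `ruff check --fix` with the `I` (import-sorting) rules enabled, though that configuration is not part of this excerpt. A minimal sketch of the style change, using illustrative names that are not taken from this commit:

# Before: yapf aligns continuation lines with the opening parenthesis.
from os.path import (abspath, basename, dirname, exists, isdir, isfile,
                     join, splitext)

# After: ruff format (black-style) uses a hanging indent,
# one name per line, and a magic trailing comma.
from os.path import (
    abspath,
    basename,
    dirname,
    exists,
    isdir,
    isfile,
    join,
    splitext,
)

print(splitext(basename(abspath(__file__))))
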


@@ -15,22 +15,36 @@ from transformers.feature_extraction_utils import BatchFeature
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig,
MultiModalKwargsItems, VideoItem)
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.inputs import (
ImageItem,
ModalityData,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .keye import (BaseKeyeModule, BaseMultiModalProcessor,
KeyeBaseDummyInputsBuilder, KeyeProcessingInfo)
from .keye import (
BaseKeyeModule,
BaseMultiModalProcessor,
KeyeBaseDummyInputsBuilder,
KeyeProcessingInfo,
)
logger = init_logger(__name__)
@@ -58,8 +72,9 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor:
return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0)
def get_num_patches(grid_thw: torch.Tensor,
num_frames: Union[list[int], torch.Tensor]) -> list[int]:
def get_num_patches(
grid_thw: torch.Tensor, num_frames: Union[list[int], torch.Tensor]
) -> list[int]:
"""
Return num_patches per video.
@@ -73,9 +88,13 @@ def get_num_patches(grid_thw: torch.Tensor,
Examples:
>>> # Suppose there are 2 videos with a total of 3 grids
>>> grid_thw = torch.tensor([[2, 2, 2], # grid 0: 2*2*2=8 patches
... [2, 2, 2], # grid 1: 2*2*2=8 patches
... [1, 1, 1]]) # grid 2: 1*1*1=1 patches
>>> grid_thw = torch.tensor(
... [
... [2, 2, 2], # grid 0: 2*2*2=8 patches
... [2, 2, 2], # grid 1: 2*2*2=8 patches
... [1, 1, 1],
... ]
... ) # grid 2: 1*1*1=1 patches
>>> num_frames = [2, 1]  # the first video contains 2 grids, the second contains 1 grid
>>> get_num_patches(grid_thw, num_frames)
@@ -90,11 +109,14 @@ def get_num_patches(grid_thw: torch.Tensor,
num_grids_per_frame = grid_thw.prod(dim=1)
start_idx_per_video = [0, *itertools.accumulate(num_frames)]
num_patches = [
num_grids_per_frame[start_idx_per_video[i]:start_idx_per_video[i + 1]].
sum() for i in range(len(num_frames))
num_grids_per_frame[start_idx_per_video[i] : start_idx_per_video[i + 1]].sum()
for i in range(len(num_frames))
]
return torch.stack(num_patches) if num_patches else torch.zeros(
0, dtype=grid_thw.dtype, device=grid_thw.device)
return (
torch.stack(num_patches)
if num_patches
else torch.zeros(0, dtype=grid_thw.dtype, device=grid_thw.device)
)
class KeyeVL1_5ImagePixelInputs(TensorSchema):
@@ -106,11 +128,12 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})
]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
@@ -124,13 +147,13 @@ class KeyeVL1_5ImageEmbeddingInputs(TensorSchema):
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["image_embeds"]
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
KeyeVL1_5ImageEmbeddingInputs]
KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, KeyeVL1_5ImageEmbeddingInputs]
class KeyeVL1_5VideoPixelInputs(TensorSchema):
@@ -142,10 +165,11 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})
]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
num_frames: torch.Tensor
@@ -160,18 +184,17 @@ class KeyeVL1_5VideoEmbeddingInputs(TensorSchema):
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["video_embeds"]
video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
num_frames: torch.Tensor
KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs,
KeyeVL1_5VideoEmbeddingInputs]
KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, KeyeVL1_5VideoEmbeddingInputs]
class KeyeVL1_5Projector(nn.Module):
def __init__(
self,
text_config: PretrainedConfig,
@@ -184,9 +207,11 @@ class KeyeVL1_5Projector(nn.Module):
self.vision_config = vision_config
self.merge_kernel_size = (2, 2)
self.hidden_size = (self.vision_config.hidden_size *
self.merge_kernel_size[0] *
self.merge_kernel_size[1])
self.hidden_size = (
self.vision_config.hidden_size
* self.merge_kernel_size[0]
* self.merge_kernel_size[1]
)
self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05)
self.act = GELUActivation()
@@ -208,15 +233,13 @@ class KeyeVL1_5Projector(nn.Module):
def forward(
self,
image_features: Union[torch.Tensor, tuple[torch.Tensor],
list[torch.Tensor]],
image_features: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]],
image_grid_thw: list[tuple[int, int, int]],
) -> Union[torch.Tensor, list[torch.Tensor]]:
m1, m2 = self.merge_kernel_size
if isinstance(image_features, (list, tuple)):
processed_features = list()
for image_feature, image_grid in zip(image_features,
image_grid_thw):
for image_feature, image_grid in zip(image_features, image_grid_thw):
t, h, w = image_grid
image_feature = rearrange(
image_feature,
@@ -238,8 +261,7 @@ class KeyeVL1_5Projector(nn.Module):
dims = image_features.shape[:-1]
dim = image_features.shape[-1]
image_features = image_features.view(np.prod(dims), dim)
hidden_states = self.pre_norm(image_features.view(
-1, self.hidden_size))
hidden_states = self.pre_norm(image_features.view(-1, self.hidden_size))
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
@@ -248,24 +270,28 @@ class KeyeVL1_5Projector(nn.Module):
class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
def get_max_frame_per_video(self) -> int:
return 2048
def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
def get_supported_mm_limits(
self,
) -> Mapping[str, Optional[int]]:
return {"image": None, "video": 1}
def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ):
image_grid_thw = hf_inputs.get("image_grid_thw",
torch.empty((0, 3), dtype=torch.int64))
def _keye_field_config(
hf_inputs: Mapping[str, torch.Tensor],
):
image_grid_thw = hf_inputs.get(
"image_grid_thw", torch.empty((0, 3), dtype=torch.int64)
)
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw",
torch.empty((0, 3), dtype=torch.int64))
video_grid_thw = hf_inputs.get(
"video_grid_thw", torch.empty((0, 3), dtype=torch.int64)
)
video_grid_thw = split_thw(video_grid_thw)
num_frames = hf_inputs.get("num_frames",
video_grid_thw[:, 0]).clone().tolist()
num_frames = hf_inputs.get("num_frames", video_grid_thw[:, 0]).clone().tolist()
video_num_patches = get_num_patches(video_grid_thw, num_frames)
@@ -285,22 +311,20 @@ def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ):
else:
j += 1
video_num_grids = torch.tensor(video_num_grids)
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
video_grid_thw=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_grids),
num_frames=MultiModalFieldConfig.batched("video"))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches
),
video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_num_patches),
video_grid_thw=MultiModalFieldConfig.flat_from_sizes("video", video_num_grids),
num_frames=MultiModalFieldConfig.batched("video"),
)
class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
@@ -336,9 +360,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data)
class KeyeVL1_5MultiModalProcessor(
BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return KeyeVL1_5MultiModalDataParser()
@@ -349,8 +371,7 @@ class KeyeVL1_5MultiModalProcessor(
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token]
@@ -359,44 +380,49 @@ class KeyeVL1_5MultiModalProcessor(
merge_length = image_processor.merge_size**2
out_mm_kwargs_data = out_mm_kwargs.get_data()
frame_types: list[torch.Tensor] = \
hf_processor_mm_kwargs.get("frame_types", None)
timestamps: list[torch.Tensor] = \
hf_processor_mm_kwargs.get("timestamps", None)
frame_types: list[torch.Tensor] = hf_processor_mm_kwargs.get(
"frame_types", None
)
timestamps: list[torch.Tensor] = hf_processor_mm_kwargs.get("timestamps", None)
num_videos = mm_items.get_count("video", strict=False)
if frame_types is None:
frame_types = [None] * num_videos
assert len(frame_types) == num_videos, \
f"Number of frame_types={len(frame_types)} " \
assert len(frame_types) == num_videos, (
f"Number of frame_types={len(frame_types)} "
f"doesn't equal to number of videos={num_videos}"
)
if timestamps is None:
timestamps = [None] * num_videos
assert len(timestamps) == num_videos, \
f"Number of timestamps={len(timestamps)} " \
assert len(timestamps) == num_videos, (
f"Number of timestamps={len(timestamps)} "
f"doesn't equal to number of videos={num_videos}"
)
video_grid_thw = out_mm_kwargs_data.get(
'video_grid_thw', torch.empty((0, 3), dtype=torch.int64))
"video_grid_thw", torch.empty((0, 3), dtype=torch.int64)
)
num_frames = out_mm_kwargs_data.get(
'num_frames', torch.tensor([], dtype=torch.int64))
"num_frames", torch.tensor([], dtype=torch.int64)
)
assert len(num_frames) == num_videos, \
f"Size of num_frames={len(num_frames)} " \
assert len(num_frames) == num_videos, (
f"Size of num_frames={len(num_frames)} "
f"doesn't equal to number of videos={num_videos}"
)
video_grid_hws = split_thw(video_grid_thw)
assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], (
f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}"
f"doesn't equal to num of frames.")
f"doesn't equal to num of frames."
)
cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()),
dim=-1)
cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()), dim=-1)
def get_replacement_keye(item_idx: int, modality: str):
"""
Args:
item_idx(int): The item index of modality to replace
item_idx(int): The item index of modality to replace
modality(str): The modality
"""
if modality == "image":
@@ -411,16 +437,15 @@ class KeyeVL1_5MultiModalProcessor(
video_timestamps = timestamps[item_idx]
video_frame_types = frame_types[item_idx]
grid_thw = video_grid_hws[
cu_seqlens[item_idx]:cu_seqlens[item_idx + 1]]
cu_seqlens[item_idx] : cu_seqlens[item_idx + 1]
]
nframes = grid_thw.shape[0]
if video_timestamps is None:
video_timestamps = [""] * nframes
else:
video_timestamps = [
format(ts, ".1f") for ts in video_timestamps
]
video_timestamps = [format(ts, ".1f") for ts in video_timestamps]
if video_frame_types is None:
video_frame_types = [0] * nframes
@@ -435,7 +460,8 @@ class KeyeVL1_5MultiModalProcessor(
placeholders.append(vocab[hf_processor.fast_end])
return PromptUpdateDetails.select_token_id(
placeholders, embed_token_id=video_token_id)
placeholders, embed_token_id=video_token_id
)
else:
raise ValueError(f"Unsupported modality {modality}")
@@ -444,7 +470,8 @@ class KeyeVL1_5MultiModalProcessor(
modality=modality,
target=[placeholder[modality]],
replacement=partial(get_replacement_keye, modality=modality),
) for modality in ("image", "video")
)
for modality in ("image", "video")
]
def _get_mm_fields_config(
@@ -456,8 +483,8 @@ class KeyeVL1_5MultiModalProcessor(
class KeyeVL1_5DummyInputsBuilder(
KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]):
...
KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]
): ...
@MULTIMODAL_REGISTRY.register_processor(
@@ -465,16 +492,17 @@ class KeyeVL1_5DummyInputsBuilder(
info=KeyeVL1_5ProcessingInfo,
dummy_inputs=KeyeVL1_5DummyInputsBuilder,
)
class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
return KeyeVL1_5Projector(text_config, vision_config, quant_config,
prefix)
class KeyeVL1_5ForConditionalGeneration(
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
):
def _build_projector(
self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return KeyeVL1_5Projector(text_config, vision_config, quant_config, prefix)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig = vllm_config.model_config.hf_config
@@ -482,7 +510,8 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
super().__init__(vllm_config=vllm_config, prefix=prefix)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
self, **kwargs: object
) -> Optional[KeyeVL1_5ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
@@ -505,7 +534,8 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[KeyeVL1_5VideoInputs]:
self, **kwargs: object
) -> Optional[KeyeVL1_5VideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
@@ -519,23 +549,27 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
num_frames=num_frames)
num_frames=num_frames,
)
if video_embeds is not None:
return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
num_frames=num_frames)
return KeyeVL1_5VideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
num_frames=num_frames,
)
def _process_video_input(
self,
video_input: KeyeVL1_5VideoInputs) -> tuple[torch.Tensor, ...]:
self, video_input: KeyeVL1_5VideoInputs
) -> tuple[torch.Tensor, ...]:
video_type = video_input["type"]
video_grid_thw = split_thw(video_input["video_grid_thw"])
pixel_values_videos = video_input.get("pixel_values_videos", None)
video_embeds = self._process_video_embeds(video_type, video_grid_thw,
pixel_values_videos)
video_embeds = self._process_video_embeds(
video_type, video_grid_thw, pixel_values_videos
)
video_embeds = torch.concat(video_embeds, dim=0)
num_frames = video_input["num_frames"].clone().tolist()
@@ -543,10 +577,11 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
num_patches = get_num_patches(video_grid_thw, num_frames).tolist()
patch_cu_seqlens = torch.cumsum(
torch.tensor([0] + num_patches).detach().clone(), dim=-1)
patch_cu_seqlens = torch.div(patch_cu_seqlens,
self.merge_size**2,
rounding_mode="floor")
torch.tensor([0] + num_patches).detach().clone(), dim=-1
)
patch_cu_seqlens = torch.div(
patch_cu_seqlens, self.merge_size**2, rounding_mode="floor"
)
new_video_embeds = []
for idx in range(patch_cu_seqlens.shape[0] - 1):