Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,16 +4,17 @@
|
||||
import itertools
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
|
||||
TypeVar, Union)
|
||||
from typing import Callable, Final, Generic, Literal, Optional, Protocol, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -27,7 +28,6 @@ class _RootConfig(Protocol[_C]):
|
||||
|
||||
|
||||
class VisionEncoderInfo(ABC, Generic[_C]):
|
||||
|
||||
def __init__(self, hf_config: _RootConfig[_C]) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -60,8 +60,7 @@ class VisionLanguageConfig(Protocol):
|
||||
vision_config: Final[PretrainedConfig]
|
||||
|
||||
|
||||
def get_vision_encoder_info(
|
||||
hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
|
||||
def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
|
||||
# Avoid circular imports
|
||||
from .clip import CLIPEncoderInfo, CLIPVisionConfig
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
|
||||
@@ -164,12 +163,13 @@ def resolve_visual_encoder_outputs(
|
||||
"""
|
||||
if select_layers is None:
|
||||
if not isinstance(encoder_outputs, torch.Tensor):
|
||||
raise ValueError("Expected only a single encoder output when "
|
||||
"`select_layers` is not provided")
|
||||
raise ValueError(
|
||||
"Expected only a single encoder output when "
|
||||
"`select_layers` is not provided"
|
||||
)
|
||||
|
||||
if feature_select_strategy is not None:
|
||||
select_features = _get_vision_feature_selector(
|
||||
feature_select_strategy)
|
||||
select_features = _get_vision_feature_selector(feature_select_strategy)
|
||||
encoder_outputs = select_features(encoder_outputs)
|
||||
|
||||
if post_layer_norm is not None:
|
||||
@@ -178,8 +178,9 @@ def resolve_visual_encoder_outputs(
|
||||
return encoder_outputs
|
||||
|
||||
if max_possible_layers is None:
|
||||
raise ValueError("`max_possible_layers` must be provided "
|
||||
"alongside `select_layers`")
|
||||
raise ValueError(
|
||||
"`max_possible_layers` must be provided alongside `select_layers`"
|
||||
)
|
||||
|
||||
# Get the hidden states corresponding to the layer indices.
|
||||
# Negative values are relative to the full visual encoder,
|
||||
@@ -191,7 +192,8 @@ def resolve_visual_encoder_outputs(
|
||||
offset = max_possible_layers - num_loaded_layers
|
||||
hs_pool = [
|
||||
encoder_outputs[layer_idx]
|
||||
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
|
||||
if layer_idx >= 0
|
||||
else encoder_outputs[layer_idx + offset]
|
||||
for layer_idx in select_layers
|
||||
]
|
||||
|
||||
@@ -207,9 +209,10 @@ def resolve_visual_encoder_outputs(
|
||||
return torch.cat(hs_pool, dim=-1)
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||
vision_model: torch.nn.Module) -> torch.Tensor:
|
||||
"""Run a vision model with data parallelism (DP) sharding. The function
|
||||
def run_dp_sharded_vision_model(
|
||||
image_input: torch.Tensor, vision_model: torch.nn.Module
|
||||
) -> torch.Tensor:
|
||||
"""Run a vision model with data parallelism (DP) sharding. The function
|
||||
will shard the input image tensor on the first dimension and run the vision
|
||||
model
|
||||
|
||||
@@ -224,18 +227,17 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||
mp_world_size = get_tensor_model_parallel_world_size()
|
||||
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
|
||||
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
|
||||
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
|
||||
pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
|
||||
image_input_padded = torch.nn.functional.pad(image_input, pad)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
image_input_per_rank = image_input_padded[rank *
|
||||
num_chunks_per_rank:(rank + 1) *
|
||||
num_chunks_per_rank, ...]
|
||||
image_input_per_rank = image_input_padded[
|
||||
rank * num_chunks_per_rank : (rank + 1) * num_chunks_per_rank, ...
|
||||
]
|
||||
|
||||
vision_embeddings = vision_model(image_input_per_rank)
|
||||
# Ensure tensor is contiguous before all_gather
|
||||
vision_embeddings = vision_embeddings.contiguous()
|
||||
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
|
||||
dim=0)
|
||||
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0)
|
||||
vision_embeddings = vision_embeddings[:num_chunks, ...]
|
||||
return vision_embeddings
|
||||
|
||||
@@ -245,27 +247,27 @@ def get_load_balance_assignment(
|
||||
num_gpus: int = 2,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Generate load balancing assignment and metadata
|
||||
Generate load balancing assignment and metadata
|
||||
for distributing data across GPUs.
|
||||
The load is determined by the total image sizes,
|
||||
not the number of images.
|
||||
|
||||
|
||||
Args:
|
||||
sizes: The size of each image
|
||||
num_gpus: Number of GPUs to balance across
|
||||
|
||||
|
||||
Returns:
|
||||
shuffle_indices:
|
||||
shuffle_indices:
|
||||
Indices to reorder data for balanced loading
|
||||
gpu_sample_counts:
|
||||
gpu_sample_counts:
|
||||
Number of samples assigned to each GPU
|
||||
grouped_sizes_per_gpu:
|
||||
grouped_sizes_per_gpu:
|
||||
Total size assigned to each GPU
|
||||
|
||||
|
||||
Example:
|
||||
```
|
||||
sizes = [1000, 100, 200, 50]
|
||||
num_gpus=2
|
||||
num_gpus = 2
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -283,9 +285,9 @@ def get_load_balance_assignment(
|
||||
# Sort indices by size (largest first for better load balancing)
|
||||
# sizes = [1000, 100, 200, 50]
|
||||
# large_to_small_indices = [0, 2, 1, 3]
|
||||
large_to_small_indices = sorted(range(n_samples),
|
||||
key=lambda i: sizes[i],
|
||||
reverse=True)
|
||||
large_to_small_indices = sorted(
|
||||
range(n_samples), key=lambda i: sizes[i], reverse=True
|
||||
)
|
||||
|
||||
for idx in large_to_small_indices:
|
||||
# Find GPU with minimum current load (by total size)
|
||||
@@ -316,11 +318,11 @@ def run_dp_sharded_mrope_vision_model(
|
||||
*,
|
||||
rope_type: Literal["rope_3d", "rope_2d"],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Run a vision model with data parallelism (DP) sharding.
|
||||
The function will shard the input image tensor on the
|
||||
"""Run a vision model with data parallelism (DP) sharding.
|
||||
The function will shard the input image tensor on the
|
||||
first dimension and run the vision model.
|
||||
This function is used to run the vision model with mrope.
|
||||
|
||||
|
||||
Args:
|
||||
vision_model (torch.nn.Module): Vision model.
|
||||
pixel_values (torch.Tensor): Image/Video input tensor.
|
||||
@@ -338,7 +340,7 @@ def run_dp_sharded_mrope_vision_model(
|
||||
vision_model.spatial_merge_size = 2
|
||||
pixel_values.shape = (1350, channel)
|
||||
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
|
||||
tp_size=2
|
||||
tp_size = 2
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -357,51 +359,57 @@ def run_dp_sharded_mrope_vision_model(
|
||||
# image_to_tp_rank = [0, 2, 1, 3]
|
||||
# gpu_sample_counts = [1, 3]
|
||||
# grouped_pixel_values_len = [1000, 350]
|
||||
(image_to_tp_rank, gpu_sample_counts,
|
||||
grouped_pixel_values_len) = get_load_balance_assignment(
|
||||
patches_per_image, tp_size)
|
||||
(image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = (
|
||||
get_load_balance_assignment(patches_per_image, tp_size)
|
||||
)
|
||||
|
||||
# cu_gpu_sample_counts = [0, 1, 4]
|
||||
cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]
|
||||
|
||||
# GPU_0 image_idxs_local = [0]
|
||||
# GPU_1 image_idxs_local = [2, 1, 3]
|
||||
image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
|
||||
cum_gpu_sample_counts[tp_rank_local +
|
||||
1]]
|
||||
image_idxs_local = image_to_tp_rank[
|
||||
cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1]
|
||||
]
|
||||
|
||||
# Get the pixel values for the local images based on the image_idxs_local
|
||||
if len(image_idxs_local) > 0:
|
||||
pixel_values_local = torch.cat([
|
||||
pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
|
||||
for i in image_idxs_local
|
||||
])
|
||||
pixel_values_local = torch.cat(
|
||||
[
|
||||
pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]]
|
||||
for i in image_idxs_local
|
||||
]
|
||||
)
|
||||
else:
|
||||
# Handle case where this rank has no images
|
||||
pixel_values_local = torch.empty((0, pixel_values.shape[1]),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
pixel_values_local = torch.empty(
|
||||
(0, pixel_values.shape[1]),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype,
|
||||
)
|
||||
# embed_dim_reduction_factor = 2 * 2
|
||||
if rope_type == "rope_2d":
|
||||
embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
|
||||
vision_model.merge_kernel_size[1])
|
||||
embed_dim_reduction_factor = (
|
||||
vision_model.merge_kernel_size[0] * vision_model.merge_kernel_size[1]
|
||||
)
|
||||
else:
|
||||
embed_dim_reduction_factor = (vision_model.spatial_merge_size *
|
||||
vision_model.spatial_merge_size)
|
||||
embed_dim_reduction_factor = (
|
||||
vision_model.spatial_merge_size * vision_model.spatial_merge_size
|
||||
)
|
||||
|
||||
# Find the max length across all ranks
|
||||
# The output embedding of every DP rank has to be
|
||||
# padded to this length for tensor_model_parallel_all_gather
|
||||
# to work
|
||||
max_len_per_rank = max(
|
||||
grouped_pixel_values_len) // embed_dim_reduction_factor
|
||||
max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor
|
||||
local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]
|
||||
|
||||
# Run the vision model on the local pixel_values_local
|
||||
if rope_type == "rope_2d":
|
||||
if pixel_values_local.shape[0] > 0:
|
||||
image_embeds_local = vision_model(
|
||||
pixel_values_local, torch.tensor(local_grid_thw_list))
|
||||
pixel_values_local, torch.tensor(local_grid_thw_list)
|
||||
)
|
||||
if isinstance(image_embeds_local, list):
|
||||
image_embeds_local = torch.cat(image_embeds_local, dim=0)
|
||||
else:
|
||||
@@ -409,16 +417,18 @@ def run_dp_sharded_mrope_vision_model(
|
||||
image_embeds_local = torch.empty(
|
||||
(0, embed_dim_reduction_factor, out_dim),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
dtype=pixel_values.dtype,
|
||||
)
|
||||
else:
|
||||
if pixel_values_local.shape[0] > 0:
|
||||
image_embeds_local = vision_model(pixel_values_local,
|
||||
local_grid_thw_list)
|
||||
image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list)
|
||||
else:
|
||||
# Handle empty case
|
||||
image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype)
|
||||
image_embeds_local = torch.empty(
|
||||
(0, vision_model.out_hidden_size),
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype,
|
||||
)
|
||||
|
||||
# Pad the output based on max_len_per_rank
|
||||
# for tensor_model_parallel_all_gather to work
|
||||
@@ -426,33 +436,40 @@ def run_dp_sharded_mrope_vision_model(
|
||||
if current_len < max_len_per_rank:
|
||||
padding_size = max_len_per_rank - current_len
|
||||
if rope_type == "rope_2d":
|
||||
padding = torch.empty((padding_size, image_embeds_local.shape[1],
|
||||
image_embeds_local.shape[2]),
|
||||
dtype=image_embeds_local.dtype,
|
||||
device=image_embeds_local.device)
|
||||
padding = torch.empty(
|
||||
(
|
||||
padding_size,
|
||||
image_embeds_local.shape[1],
|
||||
image_embeds_local.shape[2],
|
||||
),
|
||||
dtype=image_embeds_local.dtype,
|
||||
device=image_embeds_local.device,
|
||||
)
|
||||
else:
|
||||
padding = torch.empty((padding_size, image_embeds_local.shape[1]),
|
||||
dtype=image_embeds_local.dtype,
|
||||
device=image_embeds_local.device)
|
||||
image_embeds_local_padded = torch.cat([image_embeds_local, padding],
|
||||
dim=0)
|
||||
padding = torch.empty(
|
||||
(padding_size, image_embeds_local.shape[1]),
|
||||
dtype=image_embeds_local.dtype,
|
||||
device=image_embeds_local.device,
|
||||
)
|
||||
image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0)
|
||||
else:
|
||||
image_embeds_local_padded = image_embeds_local
|
||||
|
||||
# Do all_gather to collect embeddings from all ranks
|
||||
gathered_embeds = tensor_model_parallel_all_gather(
|
||||
image_embeds_local_padded, dim=0)
|
||||
gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0)
|
||||
|
||||
# Remove padding and reconstruct per-rank embeddings
|
||||
rank_embeddings = list[torch.Tensor]()
|
||||
for rank in range(tp_size):
|
||||
start_idx = rank * max_len_per_rank
|
||||
end_idx = start_idx + (grouped_pixel_values_len[rank] //
|
||||
embed_dim_reduction_factor)
|
||||
end_idx = start_idx + (
|
||||
grouped_pixel_values_len[rank] // embed_dim_reduction_factor
|
||||
)
|
||||
rank_embeddings.append(gathered_embeds[start_idx:end_idx])
|
||||
|
||||
patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
|
||||
for patch_size in patches_per_image]
|
||||
patches_per_output_image = [
|
||||
(patch_size // embed_dim_reduction_factor) for patch_size in patches_per_image
|
||||
]
|
||||
|
||||
# Reconstruct embeddings in the original order
|
||||
original_order_embeddings = [None] * len(grid_thw_list)
|
||||
@@ -463,7 +480,7 @@ def run_dp_sharded_mrope_vision_model(
|
||||
# Get images assigned to this rank in shuffled order
|
||||
# GPU_0 = image_idxs_local [0]
|
||||
# GPU_1 = image_idxs_local [2, 1, 3]
|
||||
rank_images = image_to_tp_rank[current_idx:current_idx + count]
|
||||
rank_images = image_to_tp_rank[current_idx : current_idx + count]
|
||||
|
||||
rank_embed = rank_embeddings[rank]
|
||||
# Split rank embeddings back to individual images
|
||||
@@ -471,11 +488,14 @@ def run_dp_sharded_mrope_vision_model(
|
||||
for img_idx in rank_images:
|
||||
img_patches = patches_per_output_image[img_idx]
|
||||
original_order_embeddings[img_idx] = rank_embed[
|
||||
embed_start:embed_start + img_patches]
|
||||
embed_start : embed_start + img_patches
|
||||
]
|
||||
embed_start += img_patches
|
||||
current_idx += count
|
||||
out_embeddings = tuple(embed for embed in original_order_embeddings
|
||||
if embed is not None)
|
||||
assert len(out_embeddings) == len(
|
||||
original_order_embeddings), "Found unassigned embeddings"
|
||||
out_embeddings = tuple(
|
||||
embed for embed in original_order_embeddings if embed is not None
|
||||
)
|
||||
assert len(out_embeddings) == len(original_order_embeddings), (
|
||||
"Found unassigned embeddings"
|
||||
)
|
||||
return out_embeddings
|
||||
|
||||
Reference in New Issue
Block a user