[Misc] Move DP for ViT code inside model executor dir (#25459)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-23 17:20:52 +08:00
committed by GitHub
parent 9383cd6f10
commit babad6e5dd
13 changed files with 721 additions and 730 deletions

View File

@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import math
from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union
import torch
from transformers import PretrainedConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
@@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs(
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
def run_dp_sharded_vision_model(image_input: torch.Tensor,
vision_model: torch.nn.Module) -> torch.Tensor:
"""Run a vision model with data parallelism (DP) sharding. The function
will shard the input image tensor on the first dimension and run the vision
model
Args:
image_input (torch.Tensor): Image input tensor.
vision_model (torch.nn.Module): Vision model.
Returns:
torch.Tensor: Output image embeddings
"""
num_chunks = image_input.shape[0]
mp_world_size = get_tensor_model_parallel_world_size()
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
image_input_padded = torch.nn.functional.pad(image_input, pad)
rank = get_tensor_model_parallel_rank()
image_input_per_rank = image_input_padded[rank *
num_chunks_per_rank:(rank + 1) *
num_chunks_per_rank, ...]
vision_embeddings = vision_model(image_input_per_rank)
# Ensure tensor is contiguous before all_gather
vision_embeddings = vision_embeddings.contiguous()
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
dim=0)
vision_embeddings = vision_embeddings[:num_chunks, ...]
return vision_embeddings
def get_load_balance_assignment(
sizes: list[int],
num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
"""
Generate load balancing assignment and metadata
for distributing data across GPUs.
The load is determined by the total image sizes,
not the number of images.
Args:
sizes: The size of each image
num_gpus: Number of GPUs to balance across
Returns:
shuffle_indices:
Indices to reorder data for balanced loading
gpu_sample_counts:
Number of samples assigned to each GPU
grouped_sizes_per_gpu:
Total size assigned to each GPU
Example:
```
sizes = [1000, 100, 200, 50]
num_gpus=2
```
"""
n_samples = len(sizes)
# Handle edge cases
if n_samples == 0:
return [], [0] * num_gpus, [0] * num_gpus
# Use greedy algorithm - balance by total size, not sample count
gpu_assignments = [list[int]() for _ in range(num_gpus)]
gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count
# Sort indices by size (largest first for better load balancing)
# sizes = [1000, 100, 200, 50]
# large_to_small_indices = [0, 2, 1, 3]
large_to_small_indices = sorted(range(n_samples),
key=lambda i: sizes[i],
reverse=True)
for idx in large_to_small_indices:
# Find GPU with minimum current load (by total size)
min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
gpu_assignments[min_gpu].append(idx)
gpu_loads[min_gpu] += sizes[idx]
# Create shuffle indices and counts
shuffle_indices = list[int]()
gpu_sample_counts = list[int]()
for gpu_id in range(num_gpus):
# GPU_0 = [1000] = [0]
# GPU_1 = [200, 100, 50] = [2, 1, 3]
# shuffle_indices = [0, 2, 1, 3]
shuffle_indices.extend(gpu_assignments[gpu_id])
# GPU_0 = [1]
# GPU_1 = [3]
# gpu_sample_counts = [1, 3]
gpu_sample_counts.append(len(gpu_assignments[gpu_id]))
return (shuffle_indices, gpu_sample_counts, gpu_loads)
def run_dp_sharded_mrope_vision_model(
vision_model: torch.nn.Module,
pixel_values: torch.Tensor,
grid_thw_list: list[list[int]],
*,
rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
"""Run a vision model with data parallelism (DP) sharding.
The function will shard the input image tensor on the
first dimension and run the vision model.
This function is used to run the vision model with mrope.
Args:
vision_model (torch.nn.Module): Vision model.
pixel_values (torch.Tensor): Image/Video input tensor.
grid_thw_list: List of grid dimensions for each image
rope_type: Type of rope used in the vision model.
Different rope types have different dimension to do ViT.
"rope_3d" for 3D rope (e.g., Qwen2.5-VL)
"rope_2d" for 2D rope (e.g., Kimi-VL)
Returns:
torch.Tensor: Output image embeddings
Example:
```
vision_model.out_hidden_size = 64
vision_model.spatial_merge_size = 2
pixel_values.shape = (1350, channel)
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
tp_size=2
```
"""
tp_size = get_tensor_model_parallel_world_size()
# GPU_0 tp_rank_local = 0
# GPU_1 tp_rank_local = 1
tp_rank_local = get_tensor_model_parallel_rank()
# patches_per_image = [1000, 100, 200, 50]
patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
# patches_per_image = [0, 1000, 1100, 1300, 1350]
cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]
# Get load balancing assignment with all metadata
# image_to_tp_rank = [0, 2, 1, 3]
# gpu_sample_counts = [1, 3]
# grouped_pixel_values_len = [1000, 350]
(image_to_tp_rank, gpu_sample_counts,
grouped_pixel_values_len) = get_load_balance_assignment(
patches_per_image, tp_size)
# cu_gpu_sample_counts = [0, 1, 4]
cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]
# GPU_0 image_idxs_local = [0]
# GPU_1 image_idxs_local = [2, 1, 3]
image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
cum_gpu_sample_counts[tp_rank_local +
1]]
# Get the pixel values for the local images based on the image_idxs_local
if len(image_idxs_local) > 0:
pixel_values_local = torch.cat([
pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
for i in image_idxs_local
])
else:
# Handle case where this rank has no images
pixel_values_local = torch.empty((0, pixel_values.shape[1]),
device=pixel_values.device,
dtype=pixel_values.dtype)
# embed_dim_reduction_factor = 2 * 2
if rope_type == "rope_2d":
embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
vision_model.merge_kernel_size[1])
else:
embed_dim_reduction_factor = (vision_model.spatial_merge_size *
vision_model.spatial_merge_size)
# Find the max length across all ranks
# The output embedding of every DP rank has to be
# padded to this length for tensor_model_parallel_all_gather
# to work
max_len_per_rank = max(
grouped_pixel_values_len) // embed_dim_reduction_factor
local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]
# Run the vision model on the local pixel_values_local
if rope_type == "rope_2d":
if pixel_values_local.shape[0] > 0:
image_embeds_local = vision_model(
pixel_values_local, torch.tensor(local_grid_thw_list))
if isinstance(image_embeds_local, list):
image_embeds_local = torch.cat(image_embeds_local, dim=0)
else:
out_dim = getattr(vision_model.config, "hidden_size", None)
image_embeds_local = torch.empty(
(0, embed_dim_reduction_factor, out_dim),
device=pixel_values.device,
dtype=pixel_values.dtype)
else:
if pixel_values_local.shape[0] > 0:
image_embeds_local = vision_model(pixel_values_local,
local_grid_thw_list)
else:
# Handle empty case
image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
device=pixel_values.device,
dtype=pixel_values.dtype)
# Pad the output based on max_len_per_rank
# for tensor_model_parallel_all_gather to work
current_len = image_embeds_local.shape[0]
if current_len < max_len_per_rank:
padding_size = max_len_per_rank - current_len
if rope_type == "rope_2d":
padding = torch.empty((padding_size, image_embeds_local.shape[1],
image_embeds_local.shape[2]),
dtype=image_embeds_local.dtype,
device=image_embeds_local.device)
else:
padding = torch.empty((padding_size, image_embeds_local.shape[1]),
dtype=image_embeds_local.dtype,
device=image_embeds_local.device)
image_embeds_local_padded = torch.cat([image_embeds_local, padding],
dim=0)
else:
image_embeds_local_padded = image_embeds_local
# Do all_gather to collect embeddings from all ranks
gathered_embeds = tensor_model_parallel_all_gather(
image_embeds_local_padded, dim=0)
# Remove padding and reconstruct per-rank embeddings
rank_embeddings = list[torch.Tensor]()
for rank in range(tp_size):
start_idx = rank * max_len_per_rank
end_idx = start_idx + (grouped_pixel_values_len[rank] //
embed_dim_reduction_factor)
rank_embeddings.append(gathered_embeds[start_idx:end_idx])
patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
for patch_size in patches_per_image]
# Reconstruct embeddings in the original order
original_order_embeddings = [None] * len(grid_thw_list)
current_idx = 0
for rank in range(tp_size):
count = gpu_sample_counts[rank]
if count > 0:
# Get images assigned to this rank in shuffled order
# GPU_0 = image_idxs_local [0]
# GPU_1 = image_idxs_local [2, 1, 3]
rank_images = image_to_tp_rank[current_idx:current_idx + count]
rank_embed = rank_embeddings[rank]
# Split rank embeddings back to individual images
embed_start = 0
for img_idx in rank_images:
img_patches = patches_per_output_image[img_idx]
original_order_embeddings[img_idx] = rank_embed[
embed_start:embed_start + img_patches]
embed_start += img_patches
current_idx += count
out_embeddings = tuple(embed for embed in original_order_embeddings
if embed is not None)
assert len(out_embeddings) == len(
original_order_embeddings), "Found unassigned embeddings"
return out_embeddings