[Misc] Move DP for ViT code inside model executor dir (#25459)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,10 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math

 import pytest
 import torch
+import torch.multiprocessing as mp

-from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
+from tests.utils import multi_gpu_test
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.model_executor.models.vision import (
+    get_load_balance_assignment, resolve_visual_encoder_outputs,
+    run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model)
+from vllm.platforms import current_platform
+from vllm.utils import get_open_port, update_environment_variables


 @pytest.mark.parametrize(
@@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
         post_layer_norm=None,
         max_possible_layers=max_possible_layers)
     assert torch.equal(torch.tensor(expected_features), output_tensor)
+
+
+class SimpleLinearModel(torch.nn.Module):
+    """A simple linear vision model for testing."""
+
+    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
+        super().__init__()
+        self.flatten = torch.nn.Flatten()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x: torch.Tensor):
+        # Flatten the input and apply linear transformation
+        x = self.flatten(x)
+        return self.linear(x)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "batch_size",
+    [
+        1,  # Single image
+        4,  # Small batch
+        5,  # Odd batch size (for testing padding)
+    ],
+)
+def test_run_dp_sharded_vision_model(batch_size: int):
+    world_size = 2
+    # Launch processes
+    mp.spawn(
+        run_dp_sharded_vision_model_vs_direct,
+        args=(
+            world_size,
+            batch_size,
+            get_open_port(),
+        ),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
+                                          batch_size: int, master_port: int):
+    """
+    Test that run_dp_sharded_vision_model produces the same results as
+    calling the model directly.
+    """
+
+    # Set random seed for reproducibility
+    current_platform.seed_everything(0)
+
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    # initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create a test input tensor
+    image_input = torch.randn(batch_size, 3, 224, 224)
+
+    # Create a simple linear model
+    vision_model = SimpleLinearModel()
+
+    # Run the model directly on the full input
+    with torch.inference_mode():
+        direct_output = vision_model(image_input)
+
+    # Run the model through the sharded function
+    with torch.inference_mode():
+        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
+
+    # Check that the world size is set up correctly
+    assert get_tensor_model_parallel_world_size() == world_size
+
+    # Check that the outputs have the same shape
+    assert direct_output.shape == sharded_output.shape
+
+    # Check that the outputs are close (they should be identical)
+    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
+    "expected_grouped_sizes_per_gpu,test_description",
+    [
+        # Empty input
+        ([], 2, [], [0, 0], [0, 0], "empty input"),
+
+        # Fewer samples than GPUs
+        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
+                                               ], "fewer samples than GPUs"),
+
+        # Single GPU
+        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
+
+        # Balanced assignment
+        ([100, 100, 100, 100
+          ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
+
+        # Unbalanced sizes - this one is trickier since the algorithm is greedy
+        ([1000, 100, 200, 50], 2, [0, 2, 1, 3
+                                   ], [1, 3], [1000, 350], "unbalanced sizes"),
+    ],
+)
+def test_get_load_balance_assignment_cases(sizes, num_gpus,
+                                           expected_shuffle_indices,
+                                           expected_gpu_sample_counts,
+                                           expected_grouped_sizes_per_gpu,
+                                           test_description):
+    """Test get_load_balance_assignment with various input cases."""
+    result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
+    (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
+
+    # Common assertions for all cases
+    assert len(shuffle_indices) == len(sizes)
+    assert len(gpu_sample_counts) == num_gpus
+    assert len(grouped_sizes_per_gpu) == num_gpus
+    assert sum(gpu_sample_counts) == len(sizes)
+
+    assert shuffle_indices == expected_shuffle_indices
+
+    assert gpu_sample_counts == expected_gpu_sample_counts
+    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
+
+
+class SimpleMRopeVisionModel(torch.nn.Module):
+    """A simple vision model for testing mrope functionality."""
+
+    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
+        super().__init__()
+        self.spatial_merge_size = spatial_merge_size
+        self.out_hidden_size = out_hidden_size
+        self.linear = torch.nn.Linear(768, out_hidden_size)
+
+    def forward(self, pixel_values: torch.Tensor,
+                grid_thw_list: list[list[int]]):
+        """Simple forward pass that simulates spatial merging."""
+        # Apply linear transformation
+        embeddings = self.linear(pixel_values)
+
+        # Simulate spatial merging by reducing the number of patches
+        merge_factor = self.spatial_merge_size * self.spatial_merge_size
+
+        # Group patches and merge spatially
+        merged_embeddings = []
+        start_idx = 0
+
+        for grid_thw in grid_thw_list:
+            num_patches = math.prod(grid_thw)
+            end_idx = start_idx + num_patches
+
+            # Get patches for this image
+            image_patches = embeddings[start_idx:end_idx]
+
+            # Simulate spatial merging by averaging groups of patches
+            merged_patches = num_patches // merge_factor
+            if merged_patches > 0:
+                # Reshape and average to simulate merging
+                reshaped = image_patches[:merged_patches * merge_factor].view(
+                    merged_patches, merge_factor, -1)
+                merged = reshaped.mean(dim=1)
+                merged_embeddings.append(merged)
+
+            start_idx = end_idx
+
+        if merged_embeddings:
+            return torch.cat(merged_embeddings, dim=0)
+        else:
+            return torch.empty((0, self.out_hidden_size),
+                               device=pixel_values.device,
+                               dtype=pixel_values.dtype)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "batch_size",
+    [
+        1,  # Single image
+        3,  # Small batch
+        5,  # Odd batch size (for testing padding)
+    ],
+)
+def test_run_dp_sharded_mrope_vision_model(batch_size: int):
+    world_size = 2
+    # Launch processes
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_vs_direct,
+        args=(
+            world_size,
+            batch_size,
+            get_open_port(),
+        ),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
+                                                world_size: int,
+                                                batch_size: int,
+                                                master_port: int):
+    """
+    Test that run_dp_sharded_mrope_vision_model produces the same results as
+    calling the model directly.
+    """
+    # Set random seed for reproducibility
+    current_platform.seed_everything(0)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    # initialize distributed
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create test data
+    grid_thw_list = []
+    pixel_values_list = []
+
+    for i in range(batch_size):
+        # Varying image sizes for better testing
+        t, h, w = 1, 4 + i, 4 + i
+        grid_thw_list.append([t, h, w])
+
+        num_patches = t * h * w
+        # Create random pixel values for this image
+        image_pixels = torch.randn(num_patches, 768)
+        pixel_values_list.append(image_pixels)
+
+    # Concatenate all pixel values
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+
+    # Create a simple mrope vision model
+    vision_model = SimpleMRopeVisionModel()
+
+    # Run the model directly on the full input (only on rank 0)
+    if local_rank == 0:
+        with torch.inference_mode():
+            direct_output = vision_model(pixel_values, grid_thw_list)
+
+    # Run the model through the sharded function
+    with torch.inference_mode():
+        sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
+                                                           pixel_values,
+                                                           grid_thw_list,
+                                                           rope_type="rope_3d")
+        sharded_output = torch.cat(sharded_output, dim=0)
+
+    # Check that the world size is set up correctly
+    assert get_tensor_model_parallel_world_size() == world_size
+
+    # Compare outputs (only on rank 0)
+    if local_rank == 0:
+        # Check that the outputs have the same shape
+        assert direct_output.shape == sharded_output.shape
+        # Check that the outputs are close (they should be identical)
+        assert torch.allclose(direct_output,
+                              sharded_output,
+                              rtol=1e-5,
+                              atol=1e-5)
+
+
+@multi_gpu_test(num_gpus=2)
+def test_run_dp_sharded_mrope_vision_model_empty_input():
+    world_size = 2
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_empty_input_worker,
+        args=(world_size, get_open_port()),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_empty_input_worker(
+        local_rank: int, world_size: int, master_port: int):
+    """Test run_dp_sharded_mrope_vision_model with empty input."""
+    # Set up distributed environment
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create empty inputs
+    pixel_values = torch.empty((0, 768))
+    grid_thw_list: list[list[int]] = []
+
+    vision_model = SimpleMRopeVisionModel()
+
+    # Should handle empty input gracefully
+    with torch.inference_mode():
+        output = run_dp_sharded_mrope_vision_model(vision_model,
+                                                   pixel_values,
+                                                   grid_thw_list,
+                                                   rope_type="rope_3d")
+
+    assert len(output) == 0
+
+
+@multi_gpu_test(num_gpus=4)
+def test_run_dp_sharded_mrope_vision_model_uneven_load():
+    world_size = 4
+    mp.spawn(
+        run_dp_sharded_mrope_vision_model_uneven_load_worker,
+        args=(world_size, get_open_port()),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_mrope_vision_model_uneven_load_worker(
+        local_rank: int, world_size: int, master_port: int):
+    """Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
+    # Set up distributed environment
+    current_platform.seed_everything(123)
+    device = f"{current_platform.device_name}:{local_rank}"
+    current_platform.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create images with very different sizes
+    grid_thw_list = [
+        [1, 2, 2],  # Small: 4 patches
+        [1, 8, 8],  # Large: 64 patches
+        [1, 3, 3],  # Medium: 9 patches
+    ]
+
+    pixel_values_list = []
+    for grid_thw in grid_thw_list:
+        num_patches = math.prod(grid_thw)
+        image_pixels = torch.randn(num_patches, 768)
+        pixel_values_list.append(image_pixels)
+
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+    vision_model = SimpleMRopeVisionModel()
+
+    # Should handle uneven distribution without errors
+    with torch.inference_mode():
+        output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+
+    # Verify output shape is reasonable
+    merge_factor = vision_model.spatial_merge_size**2
+    expected_output_patches = list(
+        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
+
+    for i, output in enumerate(output_tuple):
+        assert output.shape[0] == expected_output_patches[i]
+        assert output.shape[1] == vision_model.out_hidden_size
+
+
+@pytest.mark.parametrize("spatial_merge_size", [2, 4])
+def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
+    """Test SimpleMRopeVisionModel with different spatial merge sizes."""
+    device = current_platform.device_type
+
+    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
+    pixel_values_list = []
+
+    for grid_thw in grid_thw_list:
+        num_patches = math.prod(grid_thw)
+        image_pixels = torch.randn(num_patches, 768, device=device)
+        pixel_values_list.append(image_pixels)
+
+    pixel_values = torch.cat(pixel_values_list, dim=0)
+    vision_model = SimpleMRopeVisionModel(
+        spatial_merge_size=spatial_merge_size).to(device)
+
+    with torch.inference_mode():
+        output = vision_model(pixel_values, grid_thw_list)
+
+    # Verify output dimensions based on spatial merging
+    total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
+    merge_factor = spatial_merge_size**2
+    expected_output_patches = total_patches // merge_factor
+
+    assert output.shape[0] == expected_output_patches
+    assert output.shape[1] == vision_model.out_hidden_size
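Note: the expected values in the parametrized load-balance cases above all fall out of a simple greedy size-based policy. A minimal standalone sketch of that policy (plain Python, not vLLM's API; the sizes are the hypothetical values from the "unbalanced sizes" case):

# Greedy size-based assignment: visit samples largest-first, always
# filling the GPU with the smallest total size so far.
def greedy_assign(sizes: list[int], num_gpus: int):
    assignments = [[] for _ in range(num_gpus)]
    loads = [0] * num_gpus  # total size per GPU, not sample count
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        target = min(range(num_gpus), key=lambda g: loads[g])
        assignments[target].append(idx)
        loads[target] += sizes[idx]
    shuffle = [i for group in assignments for i in group]
    counts = [len(group) for group in assignments]
    return shuffle, counts, loads


shuffle, counts, loads = greedy_assign([1000, 100, 200, 50], num_gpus=2)
assert shuffle == [0, 2, 1, 3]  # GPU0 gets image 0; GPU1 gets 2, 1, 3
assert counts == [1, 3]
assert loads == [1000, 350]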
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import base64
-import math
 import mimetypes
 import os
 from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple

 import numpy as np
 import pytest
-import torch
-import torch.multiprocessing as mp
 from PIL import Image, ImageChops

-from tests.utils import multi_gpu_test
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import (init_distributed_environment,
-                                             initialize_model_parallel)
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
-from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
-                                   get_load_balance_assignment,
-                                   run_dp_sharded_mrope_vision_model,
-                                   run_dp_sharded_vision_model)
-from vllm.platforms import current_platform
-from vllm.utils import get_open_port, update_environment_variables
+from vllm.multimodal.utils import MediaConnector, argsort_mm_positions

 if TYPE_CHECKING:
     from vllm.multimodal.inputs import MultiModalPlaceholderDict
@@ -404,415 +392,3 @@ def test_argsort_mm_positions():
     modality_idxs = argsort_mm_positions(mm_positions)

     assert modality_idxs == expected_modality_idxs
-
-
-class SimpleLinearModel(torch.nn.Module):
-    """A simple linear vision model for testing."""
-
-    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
-        super().__init__()
-        self.flatten = torch.nn.Flatten()
-        self.linear = torch.nn.Linear(input_dim, output_dim)
-
-    def forward(self, x: torch.Tensor):
-        # Flatten the input and apply linear transformation
-        x = self.flatten(x)
-        return self.linear(x)
-
-
-@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize(
-    "batch_size",
-    [
-        1,  # Single image
-        4,  # Small batch
-        5,  # Odd batch size (for testing padding)
-    ],
-)
-def test_run_dp_sharded_vision_model(batch_size: int):
-    world_size = 2
-    # Launch processes
-    mp.spawn(
-        run_dp_sharded_vision_model_vs_direct,
-        args=(
-            world_size,
-            batch_size,
-            get_open_port(),
-        ),
-        nprocs=world_size,
-    )
-
-
-def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
-                                          batch_size: int, master_port: int):
-    """
-    Test that run_dp_sharded_vision_model produces the same results as
-    calling the model directly.
-    """
-
-    # Set random seed for reproducibility
-    current_platform.seed_everything(0)
-
-    device = f"{current_platform.device_name}:{local_rank}"
-    current_platform.set_device(device)
-    torch.set_default_device(device)
-
-    update_environment_variables({
-        'RANK': str(local_rank),
-        'LOCAL_RANK': str(local_rank),
-        'WORLD_SIZE': str(world_size),
-        'MASTER_ADDR': 'localhost',
-        'MASTER_PORT': str(master_port),
-    })
-
-    # initialize distributed
-    init_distributed_environment()
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-
-    # Create a test input tensor
-    image_input = torch.randn(batch_size, 3, 224, 224)
-
-    # Create a simple linear model
-    vision_model = SimpleLinearModel()
-
-    # Run the model directly on the full input
-    with torch.inference_mode():
-        direct_output = vision_model(image_input)
-
-    # Run the model through the sharded function
-    with torch.inference_mode():
-        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
-
-    # Check that the world size is set up correctly
-    assert get_tensor_model_parallel_world_size() == world_size
-
-    # Check that the outputs have the same shape
-    assert direct_output.shape == sharded_output.shape
-
-    # Check that the outputs are close (they should be identical)
-    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
-
-
-@pytest.mark.parametrize(
-    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
-    "expected_grouped_sizes_per_gpu,test_description",
-    [
-        # Empty input
-        ([], 2, [], [0, 0], [0, 0], "empty input"),
-
-        # Fewer samples than GPUs
-        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
-                                               ], "fewer samples than GPUs"),
-
-        # Single GPU
-        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),
-
-        # Balanced assignment
-        ([100, 100, 100, 100
-          ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),
-
-        # Unbalanced sizes - this one is trickier since the algorithm is greedy
-        ([1000, 100, 200, 50], 2, [0, 2, 1, 3
-                                   ], [1, 3], [1000, 350], "unbalanced sizes"),
-    ],
-)
-def test_get_load_balance_assignment_cases(sizes, num_gpus,
-                                           expected_shuffle_indices,
-                                           expected_gpu_sample_counts,
-                                           expected_grouped_sizes_per_gpu,
-                                           test_description):
-    """Test get_load_balance_assignment with various input cases."""
-    result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
-    (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result
-
-    # Common assertions for all cases
-    assert len(shuffle_indices) == len(sizes)
-    assert len(gpu_sample_counts) == num_gpus
-    assert len(grouped_sizes_per_gpu) == num_gpus
-    assert sum(gpu_sample_counts) == len(sizes)
-
-    assert shuffle_indices == expected_shuffle_indices
-
-    assert gpu_sample_counts == expected_gpu_sample_counts
-    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu
-
-
-class SimpleMRopeVisionModel(torch.nn.Module):
-    """A simple vision model for testing mrope functionality."""
-
-    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
-        super().__init__()
-        self.spatial_merge_size = spatial_merge_size
-        self.out_hidden_size = out_hidden_size
-        self.linear = torch.nn.Linear(768, out_hidden_size)
-
-    def forward(self, pixel_values: torch.Tensor,
-                grid_thw_list: list[list[int]]):
-        """Simple forward pass that simulates spatial merging."""
-        # Apply linear transformation
-        embeddings = self.linear(pixel_values)
-
-        # Simulate spatial merging by reducing the number of patches
-        merge_factor = self.spatial_merge_size * self.spatial_merge_size
-
-        # Group patches and merge spatially
-        merged_embeddings = []
-        start_idx = 0
-
-        for grid_thw in grid_thw_list:
-            num_patches = math.prod(grid_thw)
-            end_idx = start_idx + num_patches
-
-            # Get patches for this image
-            image_patches = embeddings[start_idx:end_idx]
-
-            # Simulate spatial merging by averaging groups of patches
-            merged_patches = num_patches // merge_factor
-            if merged_patches > 0:
-                # Reshape and average to simulate merging
-                reshaped = image_patches[:merged_patches * merge_factor].view(
-                    merged_patches, merge_factor, -1)
-                merged = reshaped.mean(dim=1)
-                merged_embeddings.append(merged)
-
-            start_idx = end_idx
-
-        if merged_embeddings:
-            return torch.cat(merged_embeddings, dim=0)
-        else:
-            return torch.empty((0, self.out_hidden_size),
-                               device=pixel_values.device,
-                               dtype=pixel_values.dtype)
-
-
-@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize(
-    "batch_size",
-    [
-        1,  # Single image
-        3,  # Small batch
-        5,  # Odd batch size (for testing padding)
-    ],
-)
-def test_run_dp_sharded_mrope_vision_model(batch_size: int):
-    world_size = 2
-    # Launch processes
-    mp.spawn(
-        run_dp_sharded_mrope_vision_model_vs_direct,
-        args=(
-            world_size,
-            batch_size,
-            get_open_port(),
-        ),
-        nprocs=world_size,
-    )
-
-
-def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
-                                                world_size: int,
-                                                batch_size: int,
-                                                master_port: int):
-    """
-    Test that run_dp_sharded_mrope_vision_model produces the same results as
-    calling the model directly.
-    """
-    # Set random seed for reproducibility
-    current_platform.seed_everything(0)
-    device = f"{current_platform.device_name}:{local_rank}"
-    current_platform.set_device(device)
-    torch.set_default_device(device)
-
-    update_environment_variables({
-        'RANK': str(local_rank),
-        'LOCAL_RANK': str(local_rank),
-        'WORLD_SIZE': str(world_size),
-        'MASTER_ADDR': 'localhost',
-        'MASTER_PORT': str(master_port),
-    })
-
-    # initialize distributed
-    init_distributed_environment()
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-
-    # Create test data
-    grid_thw_list = []
-    pixel_values_list = []
-
-    for i in range(batch_size):
-        # Varying image sizes for better testing
-        t, h, w = 1, 4 + i, 4 + i
-        grid_thw_list.append([t, h, w])
-
-        num_patches = t * h * w
-        # Create random pixel values for this image
-        image_pixels = torch.randn(num_patches, 768)
-        pixel_values_list.append(image_pixels)
-
-    # Concatenate all pixel values
-    pixel_values = torch.cat(pixel_values_list, dim=0)
-
-    # Create a simple mrope vision model
-    vision_model = SimpleMRopeVisionModel()
-
-    # Run the model directly on the full input (only on rank 0)
-    if local_rank == 0:
-        with torch.inference_mode():
-            direct_output = vision_model(pixel_values, grid_thw_list)
-
-    # Run the model through the sharded function
-    with torch.inference_mode():
-        sharded_output = run_dp_sharded_mrope_vision_model(vision_model,
-                                                           pixel_values,
-                                                           grid_thw_list,
-                                                           rope_type="rope_3d")
-        sharded_output = torch.cat(sharded_output, dim=0)
-
-    # Check that the world size is set up correctly
-    assert get_tensor_model_parallel_world_size() == world_size
-
-    # Compare outputs (only on rank 0)
-    if local_rank == 0:
-        # Check that the outputs have the same shape
-        assert direct_output.shape == sharded_output.shape
-        # Check that the outputs are close (they should be identical)
-        assert torch.allclose(direct_output,
-                              sharded_output,
-                              rtol=1e-5,
-                              atol=1e-5)
-
-
-@multi_gpu_test(num_gpus=2)
-def test_run_dp_sharded_mrope_vision_model_empty_input():
-    world_size = 2
-    mp.spawn(
-        run_dp_sharded_mrope_vision_model_empty_input_worker,
-        args=(world_size, get_open_port()),
-        nprocs=world_size,
-    )
-
-
-def run_dp_sharded_mrope_vision_model_empty_input_worker(
-        local_rank: int, world_size: int, master_port: int):
-    """Test run_dp_sharded_mrope_vision_model with empty input."""
-    # Set up distributed environment
-    device = f"{current_platform.device_name}:{local_rank}"
-    current_platform.set_device(device)
-    torch.set_default_device(device)
-
-    update_environment_variables({
-        'RANK': str(local_rank),
-        'LOCAL_RANK': str(local_rank),
-        'WORLD_SIZE': str(world_size),
-        'MASTER_ADDR': 'localhost',
-        'MASTER_PORT': str(master_port),
-    })
-
-    init_distributed_environment()
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-
-    # Create empty inputs
-    pixel_values = torch.empty((0, 768))
-    grid_thw_list: list[list[int]] = []
-
-    vision_model = SimpleMRopeVisionModel()
-
-    # Should handle empty input gracefully
-    with torch.inference_mode():
-        output = run_dp_sharded_mrope_vision_model(vision_model,
-                                                   pixel_values,
-                                                   grid_thw_list,
-                                                   rope_type="rope_3d")
-
-    assert len(output) == 0
-
-
-@multi_gpu_test(num_gpus=4)
-def test_run_dp_sharded_mrope_vision_model_uneven_load():
-    world_size = 4
-    mp.spawn(
-        run_dp_sharded_mrope_vision_model_uneven_load_worker,
-        args=(world_size, get_open_port()),
-        nprocs=world_size,
-    )
-
-
-def run_dp_sharded_mrope_vision_model_uneven_load_worker(
-        local_rank: int, world_size: int, master_port: int):
-    """Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
-    # Set up distributed environment
-    current_platform.seed_everything(123)
-    device = f"{current_platform.device_name}:{local_rank}"
-    current_platform.set_device(device)
-    torch.set_default_device(device)
-
-    update_environment_variables({
-        'RANK': str(local_rank),
-        'LOCAL_RANK': str(local_rank),
-        'WORLD_SIZE': str(world_size),
-        'MASTER_ADDR': 'localhost',
-        'MASTER_PORT': str(master_port),
-    })
-
-    init_distributed_environment()
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-
-    # Create images with very different sizes
-    grid_thw_list = [
-        [1, 2, 2],  # Small: 4 patches
-        [1, 8, 8],  # Large: 64 patches
-        [1, 3, 3],  # Medium: 9 patches
-    ]
-
-    pixel_values_list = []
-    for grid_thw in grid_thw_list:
-        num_patches = math.prod(grid_thw)
-        image_pixels = torch.randn(num_patches, 768)
-        pixel_values_list.append(image_pixels)
-
-    pixel_values = torch.cat(pixel_values_list, dim=0)
-    vision_model = SimpleMRopeVisionModel()
-
-    # Should handle uneven distribution without errors
-    with torch.inference_mode():
-        output_tuple = run_dp_sharded_mrope_vision_model(vision_model,
-                                                         pixel_values,
-                                                         grid_thw_list,
-                                                         rope_type="rope_3d")
-
-    # Verify output shape is reasonable
-    merge_factor = vision_model.spatial_merge_size**2
-    expected_output_patches = list(
-        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)
-
-    for i, output in enumerate(output_tuple):
-        assert output.shape[0] == expected_output_patches[i]
-        assert output.shape[1] == vision_model.out_hidden_size
-
-
-@pytest.mark.parametrize("spatial_merge_size", [2, 4])
-def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
-    """Test SimpleMRopeVisionModel with different spatial merge sizes."""
-    device = current_platform.device_type
-
-    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
-    pixel_values_list = []
-
-    for grid_thw in grid_thw_list:
-        num_patches = math.prod(grid_thw)
-        image_pixels = torch.randn(num_patches, 768, device=device)
-        pixel_values_list.append(image_pixels)
-
-    pixel_values = torch.cat(pixel_values_list, dim=0)
-    vision_model = SimpleMRopeVisionModel(
-        spatial_merge_size=spatial_merge_size).to(device)
-
-    with torch.inference_mode():
-        output = vision_model(pixel_values, grid_thw_list)
-
-    # Verify output dimensions based on spatial merging
-    total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
-    merge_factor = spatial_merge_size**2
-    expected_output_patches = total_patches // merge_factor
-
-    assert output.shape[0] == expected_output_patches
-    assert output.shape[1] == vision_model.out_hidden_size
@@ -69,7 +69,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -83,7 +82,7 @@ from .qwen2_vl import (_create_qwen2vl_field_factory,
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import get_vit_attn_backend
+from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)

@@ -34,7 +34,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import run_dp_sharded_vision_model
+
+from .vision import run_dp_sharded_vision_model


 class Idefics2VisionEmbeddings(nn.Module):
@@ -28,7 +28,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import run_dp_sharded_vision_model
+
+from .vision import run_dp_sharded_vision_model

 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -76,13 +76,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
+from .vision import run_dp_sharded_mrope_vision_model

 # For dummy input only
@@ -50,7 +50,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

@@ -58,6 +57,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llama4 import Llama4ForCausalLM
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import run_dp_sharded_vision_model


 class Llama4ImagePatchInputs(TensorSchema):
@@ -59,7 +59,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
-from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -74,7 +73,7 @@ from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import get_vit_attn_backend
+from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)

@@ -66,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -78,7 +77,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import get_vit_attn_backend
+from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)

@@ -83,7 +83,7 @@ from .qwen2_vl import Qwen2VLProcessingInfo
 from .qwen3 import Qwen3ForCausalLM, Qwen3Model
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vit_attn_backend
+from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

 logger = init_logger(__name__)

@@ -1214,8 +1214,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
             if self.use_data_parallel:
-                from vllm.multimodal.utils import (
-                    run_dp_sharded_mrope_vision_model)
                 return run_dp_sharded_mrope_vision_model(self.visual,
                                                          pixel_values,
                                                          grid_thw_list,
@@ -1245,8 +1243,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
             if self.use_data_parallel:
-                from vllm.multimodal.utils import (
-                    run_dp_sharded_mrope_vision_model)
                 return run_dp_sharded_mrope_vision_model(self.visual,
                                                          pixel_values_videos,
                                                          grid_thw_list,
@@ -31,7 +31,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -40,6 +39,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import run_dp_sharded_vision_model


 class Step3VLImagePixelInputs(TypedDict):
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import itertools
+import math
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Optional, Protocol, TypeVar, Union
+from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union

 import torch
 from transformers import PretrainedConfig

+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform

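Note: the helpers added in the next hunk lean on ceil-divide/pad/all-gather arithmetic. A quick single-process sanity check of that arithmetic (plain PyTorch; the all-gather is simulated with torch.cat, the "model" is identity, and the batch/world sizes are assumed values, not anything from the PR):

# Shard a batch of 5 across 2 ranks: pad to a multiple of the world size,
# slice per rank, gather, then trim the padding back off.
import torch

num_chunks, world_size = 5, 2
x = torch.randn(num_chunks, 3, 224, 224)

per_rank = (num_chunks + world_size - 1) // world_size  # ceil(5/2) = 3
padded = torch.nn.functional.pad(
    x, (0, 0) * (x.dim() - 1) + (0, per_rank * world_size - num_chunks))
assert padded.shape[0] == 6  # padded to 3 * 2

shards = [padded[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]
gathered = torch.cat(shards, dim=0)  # stand-in for the all-gather
assert torch.equal(gathered[:num_chunks], x)  # padding trimmed off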
@@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs(
|
|||||||
if post_layer_norm is not None and uses_last_layer:
|
if post_layer_norm is not None and uses_last_layer:
|
||||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||||
return torch.cat(hs_pool, dim=-1)
|
return torch.cat(hs_pool, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||||
|
vision_model: torch.nn.Module) -> torch.Tensor:
|
||||||
|
"""Run a vision model with data parallelism (DP) sharding. The function
|
||||||
|
will shard the input image tensor on the first dimension and run the vision
|
||||||
|
model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_input (torch.Tensor): Image input tensor.
|
||||||
|
vision_model (torch.nn.Module): Vision model.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output image embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_chunks = image_input.shape[0]
|
||||||
|
mp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
|
||||||
|
num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
|
||||||
|
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
|
||||||
|
image_input_padded = torch.nn.functional.pad(image_input, pad)
|
||||||
|
rank = get_tensor_model_parallel_rank()
|
||||||
|
image_input_per_rank = image_input_padded[rank *
|
||||||
|
num_chunks_per_rank:(rank + 1) *
|
||||||
|
num_chunks_per_rank, ...]
|
||||||
|
|
||||||
|
vision_embeddings = vision_model(image_input_per_rank)
|
||||||
|
# Ensure tensor is contiguous before all_gather
|
||||||
|
vision_embeddings = vision_embeddings.contiguous()
|
||||||
|
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
|
||||||
|
dim=0)
|
||||||
|
vision_embeddings = vision_embeddings[:num_chunks, ...]
|
||||||
|
return vision_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def get_load_balance_assignment(
|
||||||
|
sizes: list[int],
|
||||||
|
num_gpus: int = 2,
|
||||||
|
) -> tuple[list[int], list[int], list[int]]:
|
||||||
|
"""
|
||||||
|
Generate load balancing assignment and metadata
|
||||||
|
for distributing data across GPUs.
|
||||||
|
The load is determined by the total image sizes,
|
||||||
|
not the number of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sizes: The size of each image
|
||||||
|
num_gpus: Number of GPUs to balance across
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
shuffle_indices:
|
||||||
|
Indices to reorder data for balanced loading
|
||||||
|
gpu_sample_counts:
|
||||||
|
Number of samples assigned to each GPU
|
||||||
|
grouped_sizes_per_gpu:
|
||||||
|
Total size assigned to each GPU
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
sizes = [1000, 100, 200, 50]
|
||||||
|
num_gpus=2
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_samples = len(sizes)
|
||||||
|
|
||||||
|
# Handle edge cases
|
||||||
|
if n_samples == 0:
|
||||||
|
return [], [0] * num_gpus, [0] * num_gpus
|
||||||
|
|
||||||
|
# Use greedy algorithm - balance by total size, not sample count
|
||||||
|
gpu_assignments = [list[int]() for _ in range(num_gpus)]
|
||||||
|
gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count
|
||||||
|
|
||||||
|
# Sort indices by size (largest first for better load balancing)
|
||||||
|
# sizes = [1000, 100, 200, 50]
|
||||||
|
# large_to_small_indices = [0, 2, 1, 3]
|
||||||
|
large_to_small_indices = sorted(range(n_samples),
|
||||||
|
key=lambda i: sizes[i],
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
for idx in large_to_small_indices:
|
||||||
|
# Find GPU with minimum current load (by total size)
|
||||||
|
min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
|
||||||
|
gpu_assignments[min_gpu].append(idx)
|
||||||
|
gpu_loads[min_gpu] += sizes[idx]
|
||||||
|
|
||||||
|
# Create shuffle indices and counts
|
||||||
|
shuffle_indices = list[int]()
|
||||||
|
gpu_sample_counts = list[int]()
|
||||||
|
for gpu_id in range(num_gpus):
|
||||||
|
# GPU_0 = [1000] = [0]
|
||||||
|
# GPU_1 = [200, 100, 50] = [2, 1, 3]
|
||||||
|
# shuffle_indices = [0, 2, 1, 3]
|
||||||
|
shuffle_indices.extend(gpu_assignments[gpu_id])
|
||||||
|
# GPU_0 = [1]
|
||||||
|
# GPU_1 = [3]
|
||||||
|
# gpu_sample_counts = [1, 3]
|
||||||
|
gpu_sample_counts.append(len(gpu_assignments[gpu_id]))
|
||||||
|
|
||||||
|
return (shuffle_indices, gpu_sample_counts, gpu_loads)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dp_sharded_mrope_vision_model(
|
||||||
|
vision_model: torch.nn.Module,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
grid_thw_list: list[list[int]],
|
||||||
|
*,
|
||||||
|
rope_type: Literal["rope_3d", "rope_2d"],
|
||||||
|
) -> tuple[torch.Tensor, ...]:
|
||||||
|
"""Run a vision model with data parallelism (DP) sharding.
|
||||||
|
The function will shard the input image tensor on the
|
||||||
|
first dimension and run the vision model.
|
||||||
|
This function is used to run the vision model with mrope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_model (torch.nn.Module): Vision model.
|
||||||
|
pixel_values (torch.Tensor): Image/Video input tensor.
|
||||||
|
grid_thw_list: List of grid dimensions for each image
|
||||||
|
rope_type: Type of rope used in the vision model.
|
||||||
|
Different rope types have different dimension to do ViT.
|
||||||
|
"rope_3d" for 3D rope (e.g., Qwen2.5-VL)
|
||||||
|
"rope_2d" for 2D rope (e.g., Kimi-VL)
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output image embeddings
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
vision_model.out_hidden_size = 64
|
||||||
|
vision_model.spatial_merge_size = 2
|
||||||
|
pixel_values.shape = (1350, channel)
|
||||||
|
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
|
||||||
|
tp_size=2
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata
    # image_to_tp_rank = [0, 2, 1, 3]
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts,
     grouped_pixel_values_len) = get_load_balance_assignment(
         patches_per_image, tp_size)

    # cum_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
                                        cum_gpu_sample_counts[tp_rank_local +
                                                              1]]

    # Get the pixel values for the local images based on image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat([
            pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
            for i in image_idxs_local
        ])
    else:
        # Handle the case where this rank has no images
        pixel_values_local = torch.empty((0, pixel_values.shape[1]),
                                         device=pixel_values.device,
                                         dtype=pixel_values.dtype)

    # embed_dim_reduction_factor = 2 * 2
    if rope_type == "rope_2d":
        embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
                                      vision_model.merge_kernel_size[1])
    else:
        embed_dim_reduction_factor = (vision_model.spatial_merge_size *
                                      vision_model.spatial_merge_size)

    # Find the max length across all ranks.
    # The output embedding of every DP rank has to be padded to this
    # length for tensor_model_parallel_all_gather to work.
    max_len_per_rank = max(
        grouped_pixel_values_len) // embed_dim_reduction_factor
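    # In the docstring example: max(1000, 350) // (2 * 2) = 250 rows per rank.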
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if rope_type == "rope_2d":
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(
                pixel_values_local, torch.tensor(local_grid_thw_list))
            if isinstance(image_embeds_local, list):
                image_embeds_local = torch.cat(image_embeds_local, dim=0)
        else:
            out_dim = getattr(vision_model.config, "hidden_size", None)
            image_embeds_local = torch.empty(
                (0, embed_dim_reduction_factor, out_dim),
                device=pixel_values.device,
                dtype=pixel_values.dtype)
    else:
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(pixel_values_local,
                                              local_grid_thw_list)
        else:
            # Handle the empty case
            image_embeds_local = torch.empty(
                (0, vision_model.out_hidden_size),
                device=pixel_values.device,
                dtype=pixel_values.dtype)

    # Pad the output based on max_len_per_rank
    # for tensor_model_parallel_all_gather to work
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding_size = max_len_per_rank - current_len
        if rope_type == "rope_2d":
            padding = torch.empty((padding_size, image_embeds_local.shape[1],
                                   image_embeds_local.shape[2]),
                                  dtype=image_embeds_local.dtype,
                                  device=image_embeds_local.device)
        else:
            padding = torch.empty((padding_size, image_embeds_local.shape[1]),
                                  dtype=image_embeds_local.dtype,
                                  device=image_embeds_local.device)
        image_embeds_local_padded = torch.cat([image_embeds_local, padding],
                                              dim=0)
    else:
        image_embeds_local_padded = image_embeds_local
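
    # Every rank now holds exactly max_len_per_rank rows, so the gather
    # below returns tp_size * max_len_per_rank rows (2 * 250 = 500 in the
    # docstring example).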
    # Do all_gather to collect embeddings from all ranks
    gathered_embeds = tensor_model_parallel_all_gather(
        image_embeds_local_padded, dim=0)

    # Remove padding and reconstruct per-rank embeddings
    rank_embeddings = list[torch.Tensor]()
    for rank in range(tp_size):
        start_idx = rank * max_len_per_rank
        end_idx = start_idx + (grouped_pixel_values_len[rank] //
                               embed_dim_reduction_factor)
        rank_embeddings.append(gathered_embeds[start_idx:end_idx])
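
    # Each image's output length is its patch count divided by the merge
    # factor, since the vision model merges patches before projecting.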
    patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
                                for patch_size in patches_per_image]

    # Reconstruct embeddings in the original order
    original_order_embeddings = [None] * len(grid_thw_list)
    current_idx = 0
    for rank in range(tp_size):
        count = gpu_sample_counts[rank]
        if count > 0:
            # Get the images assigned to this rank in shuffled order
            # GPU_0 image_idxs_local = [0]
            # GPU_1 image_idxs_local = [2, 1, 3]
            rank_images = image_to_tp_rank[current_idx:current_idx + count]

            rank_embed = rank_embeddings[rank]
            # Split the rank embeddings back into individual images
            embed_start = 0
            for img_idx in rank_images:
                img_patches = patches_per_output_image[img_idx]
                original_order_embeddings[img_idx] = rank_embed[
                    embed_start:embed_start + img_patches]
                embed_start += img_patches
            current_idx += count
    out_embeddings = tuple(embed for embed in original_order_embeddings
                           if embed is not None)
    assert len(out_embeddings) == len(
        original_order_embeddings), "Found unassigned embeddings"
    return out_embeddings
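
# A minimal usage sketch (illustrative; assumes a Qwen2.5-VL-style vision
# tower that defines `spatial_merge_size` and `out_hidden_size`):
#
#     embeds = run_dp_sharded_mrope_vision_model(vision_tower,
#                                                pixel_values,
#                                                grid_thw_list,
#                                                rope_type="rope_3d")
#     assert len(embeds) == len(grid_thw_list)
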
@@ -3,13 +3,11 @@
 
 import asyncio
 import atexit
-import itertools
-import math
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 from urllib.parse import ParseResult, urlparse
 from urllib.request import url2pathname
@@ -21,9 +19,6 @@ from typing_extensions import deprecated
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_gather)
 
 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -33,12 +28,10 @@ from .video import VideoMediaIO
 _M = TypeVar("_M")
 
 if TYPE_CHECKING:
-    from .inputs import (BatchedTensorInputs, MultiModalKwargs,
-                         MultiModalKwargsItem, MultiModalKwargsItems,
-                         MultiModalPlaceholderDict)
+    from .inputs import (BatchedTensorInputs, MultiModalKwargsItem,
+                         MultiModalKwargsItems, MultiModalPlaceholderDict)
 else:
     BatchedTensorInputs = Any
-    MultiModalKwargs = Any
     MultiModalKwargsItem = Any
     MultiModalKwargsItems = Any
     MultiModalPlaceholderDict = Any
@@ -93,7 +86,7 @@ class MediaConnector:
         self,
         url_spec: ParseResult,
         media_io: MediaIO[_M],
-    ) -> _M:
+    ) -> _M:  # type: ignore[type-var]
         data_spec, data = url_spec.path.split(",", 1)
         media_type, data_type = data_spec.split(";", 1)
 
@@ -107,7 +100,7 @@ class MediaConnector:
         self,
         url_spec: ParseResult,
         media_io: MediaIO[_M],
-    ) -> _M:
+    ) -> _M:  # type: ignore[type-var]
         allowed_local_media_path = self.allowed_local_media_path
         if allowed_local_media_path is None:
             raise RuntimeError("Cannot load local files without "
@@ -127,7 +120,7 @@ class MediaConnector:
         media_io: MediaIO[_M],
         *,
         fetch_timeout: Optional[int] = None,
-    ) -> _M:
+    ) -> _M:  # type: ignore[type-var]
        url_spec = urlparse(url)
 
        if url_spec.scheme.startswith("http"):
@@ -434,280 +427,6 @@ def group_mm_kwargs_by_modality(
         yield modality, len(items_lst), mm_kwargs_group
 
 
-def run_dp_sharded_vision_model(image_input: torch.Tensor,
-                                vision_model: torch.nn.Module) -> torch.Tensor:
-    """Run a vision model with data parallelism (DP) sharding. The function
-    will shard the input image tensor on the first dimension and run the
-    vision model.
-
-    Args:
-        image_input (torch.Tensor): Image input tensor.
-        vision_model (torch.nn.Module): Vision model.
-    Returns:
-        torch.Tensor: Output image embeddings
-    """
-
-    num_chunks = image_input.shape[0]
-    mp_world_size = get_tensor_model_parallel_world_size()
-    num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
-    num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
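-    # Illustrative arithmetic: with num_chunks = 5 and mp_world_size = 2,
-    # num_chunks_per_rank = 3 and num_padded_chunks = 1.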
-    pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
-    image_input_padded = torch.nn.functional.pad(image_input, pad)
-    rank = get_tensor_model_parallel_rank()
-    image_input_per_rank = image_input_padded[rank *
-                                              num_chunks_per_rank:(rank + 1) *
-                                              num_chunks_per_rank, ...]
-
-    vision_embeddings = vision_model(image_input_per_rank)
-    # Ensure tensor is contiguous before all_gather
-    vision_embeddings = vision_embeddings.contiguous()
-    vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
-                                                         dim=0)
-    vision_embeddings = vision_embeddings[:num_chunks, ...]
-    return vision_embeddings
-
-
-def get_load_balance_assignment(
-    sizes: list[int],
-    num_gpus: int = 2,
-) -> tuple[list[int], list[int], list[int]]:
-    """
-    Generate a load balancing assignment and metadata
-    for distributing data across GPUs.
-    The load is determined by the total image sizes,
-    not the number of images.
-
-    Args:
-        sizes: The size of each image
-        num_gpus: Number of GPUs to balance across
-
-    Returns:
-        shuffle_indices:
-            Indices to reorder data for balanced loading
-        gpu_sample_counts:
-            Number of samples assigned to each GPU
-        grouped_sizes_per_gpu:
-            Total size assigned to each GPU
-
-    Example:
-        ```
-        sizes = [1000, 100, 200, 50]
-        num_gpus = 2
-        ```
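-
-        Expected result for these (illustrative) inputs, matching the
-        inline comments in the body below:
-        ```
-        shuffle_indices = [0, 2, 1, 3]
-        gpu_sample_counts = [1, 3]
-        grouped_sizes_per_gpu = [1000, 350]
-        ```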
-    """
-
-    n_samples = len(sizes)
-
-    # Handle edge cases
-    if n_samples == 0:
-        return [], [0] * num_gpus, [0] * num_gpus
-
-    # Use a greedy algorithm - balance by total size, not sample count
-    gpu_assignments = [list[int]() for _ in range(num_gpus)]
-    gpu_loads = [0] * num_gpus  # This tracks total SIZE, not sample count
-
-    # Sort indices by size (largest first for better load balancing)
-    # sizes = [1000, 100, 200, 50]
-    # large_to_small_indices = [0, 2, 1, 3]
-    large_to_small_indices = sorted(range(n_samples),
-                                    key=lambda i: sizes[i],
-                                    reverse=True)
-
-    for idx in large_to_small_indices:
-        # Find the GPU with the minimum current load (by total size)
-        min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
-        gpu_assignments[min_gpu].append(idx)
-        gpu_loads[min_gpu] += sizes[idx]
-
-    # Create shuffle indices and counts
-    shuffle_indices = list[int]()
-    gpu_sample_counts = list[int]()
-    for gpu_id in range(num_gpus):
-        # GPU_0 = [1000] = [0]
-        # GPU_1 = [200, 100, 50] = [2, 1, 3]
-        # shuffle_indices = [0, 2, 1, 3]
-        shuffle_indices.extend(gpu_assignments[gpu_id])
-        # GPU_0 count = 1
-        # GPU_1 count = 3
-        # gpu_sample_counts = [1, 3]
-        gpu_sample_counts.append(len(gpu_assignments[gpu_id]))
-
-    return (shuffle_indices, gpu_sample_counts, gpu_loads)
-
-
-def run_dp_sharded_mrope_vision_model(
-    vision_model: torch.nn.Module,
-    pixel_values: torch.Tensor,
-    grid_thw_list: list[list[int]],
-    *,
-    rope_type: Literal["rope_3d", "rope_2d"],
-) -> tuple[torch.Tensor, ...]:
-    """Run a vision model with data parallelism (DP) sharding.
-
-    The function shards the input image tensor on the first dimension and
-    runs the vision model on each shard. It is used to run vision models
-    that use mrope.
-
-    Args:
-        vision_model (torch.nn.Module): Vision model.
-        pixel_values (torch.Tensor): Image/video input tensor.
-        grid_thw_list: List of grid dimensions for each image.
-        rope_type: Type of rope used in the vision model. Different rope
-            types shard the ViT input along different dimensions.
-            "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
-            "rope_2d" for 2D rope (e.g., Kimi-VL)
-
-    Returns:
-        tuple[torch.Tensor, ...]: Output image embeddings, one tensor per
-            input image, in the original input order.
-
-    Example:
-        ```
-        vision_model.out_hidden_size = 64
-        vision_model.spatial_merge_size = 2
-        pixel_values.shape = (1350, channel)
-        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
-        tp_size = 2
-        ```
-    """
-    tp_size = get_tensor_model_parallel_world_size()
-
-    # GPU_0 tp_rank_local = 0
-    # GPU_1 tp_rank_local = 1
-    tp_rank_local = get_tensor_model_parallel_rank()
-
-    # patches_per_image = [1000, 100, 200, 50]
-    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
-    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
-    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]
-
-    # Get load balancing assignment with all metadata
-    # image_to_tp_rank = [0, 2, 1, 3]
-    # gpu_sample_counts = [1, 3]
-    # grouped_pixel_values_len = [1000, 350]
-    (image_to_tp_rank, gpu_sample_counts,
-     grouped_pixel_values_len) = get_load_balance_assignment(
-         patches_per_image, tp_size)
-
-    # cum_gpu_sample_counts = [0, 1, 4]
-    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]
-
-    # GPU_0 image_idxs_local = [0]
-    # GPU_1 image_idxs_local = [2, 1, 3]
-    image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
-                                        cum_gpu_sample_counts[tp_rank_local +
-                                                              1]]
-
-    # Get the pixel values for the local images based on image_idxs_local
-    if len(image_idxs_local) > 0:
-        pixel_values_local = torch.cat([
-            pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
-            for i in image_idxs_local
-        ])
-    else:
-        # Handle the case where this rank has no images
-        pixel_values_local = torch.empty((0, pixel_values.shape[1]),
-                                         device=pixel_values.device,
-                                         dtype=pixel_values.dtype)
-
-    # embed_dim_reduction_factor = 2 * 2
-    if rope_type == "rope_2d":
-        embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] *
-                                      vision_model.merge_kernel_size[1])
-    else:
-        embed_dim_reduction_factor = (vision_model.spatial_merge_size *
-                                      vision_model.spatial_merge_size)
-
-    # Find the max length across all ranks.
-    # The output embedding of every DP rank has to be padded to this
-    # length for tensor_model_parallel_all_gather to work.
-    max_len_per_rank = max(
-        grouped_pixel_values_len) // embed_dim_reduction_factor
-    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]
-
-    # Run the vision model on the local pixel_values_local
-    if rope_type == "rope_2d":
-        if pixel_values_local.shape[0] > 0:
-            image_embeds_local = vision_model(
-                pixel_values_local, torch.tensor(local_grid_thw_list))
-            if isinstance(image_embeds_local, list):
-                image_embeds_local = torch.cat(image_embeds_local, dim=0)
-        else:
-            out_dim = getattr(vision_model.config, "hidden_size", None)
-            image_embeds_local = torch.empty(
-                (0, embed_dim_reduction_factor, out_dim),
-                device=pixel_values.device,
-                dtype=pixel_values.dtype)
-    else:
-        if pixel_values_local.shape[0] > 0:
-            image_embeds_local = vision_model(pixel_values_local,
-                                              local_grid_thw_list)
-        else:
-            # Handle the empty case
-            image_embeds_local = torch.empty(
-                (0, vision_model.out_hidden_size),
-                device=pixel_values.device,
-                dtype=pixel_values.dtype)
-
-    # Pad the output based on max_len_per_rank
-    # for tensor_model_parallel_all_gather to work
-    current_len = image_embeds_local.shape[0]
-    if current_len < max_len_per_rank:
-        padding_size = max_len_per_rank - current_len
-        if rope_type == "rope_2d":
-            padding = torch.empty((padding_size, image_embeds_local.shape[1],
-                                   image_embeds_local.shape[2]),
-                                  dtype=image_embeds_local.dtype,
-                                  device=image_embeds_local.device)
-        else:
-            padding = torch.empty((padding_size, image_embeds_local.shape[1]),
-                                  dtype=image_embeds_local.dtype,
-                                  device=image_embeds_local.device)
-        image_embeds_local_padded = torch.cat([image_embeds_local, padding],
-                                              dim=0)
-    else:
-        image_embeds_local_padded = image_embeds_local
-
-    # Do all_gather to collect embeddings from all ranks
-    gathered_embeds = tensor_model_parallel_all_gather(
-        image_embeds_local_padded, dim=0)
-
-    # Remove padding and reconstruct per-rank embeddings
-    rank_embeddings = list[torch.Tensor]()
-    for rank in range(tp_size):
-        start_idx = rank * max_len_per_rank
-        end_idx = start_idx + (grouped_pixel_values_len[rank] //
-                               embed_dim_reduction_factor)
-        rank_embeddings.append(gathered_embeds[start_idx:end_idx])
-
-    patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
-                                for patch_size in patches_per_image]
-
-    # Reconstruct embeddings in the original order
-    original_order_embeddings = [None] * len(grid_thw_list)
-    current_idx = 0
-    for rank in range(tp_size):
-        count = gpu_sample_counts[rank]
-        if count > 0:
-            # Get the images assigned to this rank in shuffled order
-            # GPU_0 image_idxs_local = [0]
-            # GPU_1 image_idxs_local = [2, 1, 3]
-            rank_images = image_to_tp_rank[current_idx:current_idx + count]
-
-            rank_embed = rank_embeddings[rank]
-            # Split the rank embeddings back into individual images
-            embed_start = 0
-            for img_idx in rank_images:
-                img_patches = patches_per_output_image[img_idx]
-                original_order_embeddings[img_idx] = rank_embed[
-                    embed_start:embed_start + img_patches]
-                embed_start += img_patches
-            current_idx += count
-    out_embeddings = tuple(embed for embed in original_order_embeddings
-                           if embed is not None)
-    assert len(out_embeddings) == len(
-        original_order_embeddings), "Found unassigned embeddings"
-    return out_embeddings
-
-
 def fetch_audio(
     audio_url: str,
     audio_io_kwargs: Optional[dict[str, Any]] = None,