[Model] enable data parallel for Llama4 vision encoder (#18368)

Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
This commit is contained in:
jennyyyyzhen
2025-06-02 04:22:54 -07:00
committed by GitHub
parent 5b168b6d7a
commit ebb1ec9318
4 changed files with 214 additions and 68 deletions

View File

@@ -34,6 +34,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -49,6 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -84,23 +86,29 @@ class Llama4ImagePatchInputs(TypedDict):
class Llama4VisionMLP(nn.Module):
def __init__(self,
input_size: int,
intermediate_size: int,
output_size: int,
bias: bool,
output_activation: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
input_size: int,
intermediate_size: int,
output_size: int,
bias: bool,
output_activation: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.fc1 = ColumnParallelLinear(
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
input_size=input_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
self.fc2 = cls_fc2(
input_size=intermediate_size,
output_size=output_size,
bias=bias,
@@ -155,10 +163,12 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
int(channels / shuffle_ratio))
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
reshaped_tensor = reshaped_tensor.view(batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)))
reshaped_tensor = reshaped_tensor.view(
batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)),
)
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
output_tensor = reshaped_tensor.view(batch_size, -1,
@@ -173,6 +183,7 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
@@ -186,7 +197,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
bias=config.multi_modal_projector_bias,
output_activation=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = pixel_shuffle(encoded_patches,
@@ -201,10 +214,12 @@ class Llama4VisionAttention(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
@@ -217,22 +232,39 @@ class Llama4VisionAttention(nn.Module):
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
self.scaling)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
if use_data_parallel:
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
self.q_size + 2 * self.kv_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = ReplicatedLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
@@ -275,22 +307,29 @@ class Llama4VisionEncoderLayer(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.intermediate_size = config.intermediate_size
self.self_attn = Llama4VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
intermediate_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
output_activation=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.self_attn = Llama4VisionAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
)
self.mlp = Llama4VisionMLP(
input_size=config.hidden_size,
intermediate_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
output_activation=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
@@ -322,6 +361,7 @@ class Llama4VisionEncoder(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@@ -330,6 +370,7 @@ class Llama4VisionEncoder(nn.Module):
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel,
) for layer_idx in range(config.num_hidden_layers)
])
@@ -357,23 +398,33 @@ class Llama4VisionEncoder(nn.Module):
class Llama4UnfoldConvolution(nn.Module):
def __init__(self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
kernel_size = config.patch_size
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
stride=config.patch_size)
self.linear = ColumnParallelLinear(config.num_channels *
kernel_size[0] * kernel_size[1],
config.hidden_size,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.linear")
params = {
"input_size":
config.num_channels * kernel_size[0] * kernel_size[1],
"output_size": config.hidden_size,
"bias": False,
"quant_config": quant_config,
"prefix": f"{prefix}.linear",
}
if use_data_parallel:
cls = ReplicatedLinear
else:
cls = ColumnParallelLinear
params["gather_output"] = True
self.linear = cls(**params)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
@@ -389,6 +440,7 @@ class Llama4VisionModel(nn.Module):
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@@ -403,7 +455,9 @@ class Llama4VisionModel(nn.Module):
self.patch_embedding = Llama4UnfoldConvolution(
config,
quant_config=quant_config,
prefix=f"{prefix}.patch_embedding")
prefix=f"{prefix}.patch_embedding",
use_data_parallel=use_data_parallel,
)
self.class_embedding = nn.Parameter(self.scale *
torch.randn(self.hidden_size))
@@ -415,11 +469,18 @@ class Llama4VisionModel(nn.Module):
self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
# encoders
self.model = Llama4VisionEncoder(config,
quant_config=quant_config,
prefix=f"{prefix}.model")
self.model = Llama4VisionEncoder(
config,
quant_config=quant_config,
prefix=f"{prefix}.model",
use_data_parallel=use_data_parallel,
)
self.vision_adapter = Llama4VisionPixelShuffleMLP(
config, quant_config, prefix=f"{prefix}.vision_adapter")
config,
quant_config,
prefix=f"{prefix}.vision_adapter",
use_data_parallel=use_data_parallel,
)
def forward(
self,
@@ -528,8 +589,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
vision_config = self.info.get_hf_config().vision_config
if processed_outputs.get("pixel_values") is not None:
assert "images" in mm_data, \
"images expected to be in mm_data when pixel_values is present"
assert (
"images" in mm_data
), "images expected to be in mm_data when pixel_values is present"
images = mm_data["images"]
parsed_images = (self._get_data_parser().parse_mm_data({
@@ -546,8 +608,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
get_best_fit(
(image.size[1], image.size[0]),
torch.tensor(possible_resolutions),
resize_to_max_canvas=image_processor.resize_to_max_canvas)
for image in parsed_images
resize_to_max_canvas=image_processor.resize_to_max_canvas,
) for image in parsed_images
]
# TODO tile height/width do not necessarily need to match
aspect_ratios = [(image_size[0] // tile_size,
@@ -659,13 +721,17 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_model = Llama4VisionModel(config.vision_config,
None,
prefix=maybe_prefix(
prefix, "vision_model"))
self.vision_model = Llama4VisionModel(
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
self.multi_modal_projector = Llama4MultiModalProjector(
self.config,
None,
@@ -709,7 +775,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data)
# shard image input
if self.use_data_parallel:
vision_embeddings_flat = run_dp_sharded_vision_model(
flat_data, self.vision_model)
else:
vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.multi_modal_projector(
vision_embeddings_flat)
@@ -796,6 +868,30 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_prefix_weights(), get_other_weights()
def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
@@ -818,9 +914,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
if self.use_data_parallel:
other_weights = self._consolidate_qkv_weights(other_weights)
for name, loaded_weight in other_weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
if weight_name not in name or self.use_data_parallel:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]