[Model] enable data parallel for Llama4 vision encoder (#18368)
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com> Co-authored-by: yZhen <yZhen@fb.com> Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
@@ -49,6 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
@@ -84,23 +86,29 @@ class Llama4ImagePatchInputs(TypedDict):
|
||||
|
||||
class Llama4VisionMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
intermediate_size: int,
|
||||
output_size: int,
|
||||
bias: bool,
|
||||
output_activation: bool,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
intermediate_size: int,
|
||||
output_size: int,
|
||||
bias: bool,
|
||||
output_activation: bool,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
cls_fc1 = (ReplicatedLinear
|
||||
if use_data_parallel else ColumnParallelLinear)
|
||||
self.fc1 = cls_fc1(
|
||||
input_size=input_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
|
||||
self.fc2 = cls_fc2(
|
||||
input_size=intermediate_size,
|
||||
output_size=output_size,
|
||||
bias=bias,
|
||||
@@ -155,10 +163,12 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
|
||||
int(channels / shuffle_ratio))
|
||||
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
reshaped_tensor = reshaped_tensor.view(batch_size,
|
||||
int(height * shuffle_ratio),
|
||||
int(width * shuffle_ratio),
|
||||
int(channels / (shuffle_ratio**2)))
|
||||
reshaped_tensor = reshaped_tensor.view(
|
||||
batch_size,
|
||||
int(height * shuffle_ratio),
|
||||
int(width * shuffle_ratio),
|
||||
int(channels / (shuffle_ratio**2)),
|
||||
)
|
||||
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
output_tensor = reshaped_tensor.view(batch_size, -1,
|
||||
@@ -173,6 +183,7 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
|
||||
@@ -186,7 +197,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
|
||||
bias=config.multi_modal_projector_bias,
|
||||
output_activation=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = pixel_shuffle(encoded_patches,
|
||||
@@ -201,10 +214,12 @@ class Llama4VisionAttention(nn.Module):
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
@@ -217,22 +232,39 @@ class Llama4VisionAttention(nn.Module):
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
|
||||
self.scaling)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
if use_data_parallel:
|
||||
self.qkv_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = ReplicatedLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
head_size=self.head_dim,
|
||||
@@ -275,22 +307,29 @@ class Llama4VisionEncoderLayer(nn.Module):
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.self_attn = Llama4VisionAttention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
output_size=config.hidden_size,
|
||||
bias=True,
|
||||
output_activation=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.self_attn = Llama4VisionAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
self.mlp = Llama4VisionMLP(
|
||||
input_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
output_size=config.hidden_size,
|
||||
bias=True,
|
||||
output_activation=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
@@ -322,6 +361,7 @@ class Llama4VisionEncoder(nn.Module):
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -330,6 +370,7 @@ class Llama4VisionEncoder(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
) for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@@ -357,23 +398,33 @@ class Llama4VisionEncoder(nn.Module):
|
||||
|
||||
class Llama4UnfoldConvolution(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = config.patch_size
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
|
||||
stride=config.patch_size)
|
||||
self.linear = ColumnParallelLinear(config.num_channels *
|
||||
kernel_size[0] * kernel_size[1],
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
gather_output=True,
|
||||
prefix=f"{prefix}.linear")
|
||||
params = {
|
||||
"input_size":
|
||||
config.num_channels * kernel_size[0] * kernel_size[1],
|
||||
"output_size": config.hidden_size,
|
||||
"bias": False,
|
||||
"quant_config": quant_config,
|
||||
"prefix": f"{prefix}.linear",
|
||||
}
|
||||
if use_data_parallel:
|
||||
cls = ReplicatedLinear
|
||||
else:
|
||||
cls = ColumnParallelLinear
|
||||
params["gather_output"] = True
|
||||
self.linear = cls(**params)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.unfold(hidden_states)
|
||||
@@ -389,6 +440,7 @@ class Llama4VisionModel(nn.Module):
|
||||
config: Llama4VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -403,7 +455,9 @@ class Llama4VisionModel(nn.Module):
|
||||
self.patch_embedding = Llama4UnfoldConvolution(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.patch_embedding")
|
||||
prefix=f"{prefix}.patch_embedding",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.class_embedding = nn.Parameter(self.scale *
|
||||
torch.randn(self.hidden_size))
|
||||
@@ -415,11 +469,18 @@ class Llama4VisionModel(nn.Module):
|
||||
self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
|
||||
# encoders
|
||||
self.model = Llama4VisionEncoder(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.model")
|
||||
self.model = Llama4VisionEncoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.model",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
self.vision_adapter = Llama4VisionPixelShuffleMLP(
|
||||
config, quant_config, prefix=f"{prefix}.vision_adapter")
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.vision_adapter",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -528,8 +589,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
vision_config = self.info.get_hf_config().vision_config
|
||||
|
||||
if processed_outputs.get("pixel_values") is not None:
|
||||
assert "images" in mm_data, \
|
||||
"images expected to be in mm_data when pixel_values is present"
|
||||
assert (
|
||||
"images" in mm_data
|
||||
), "images expected to be in mm_data when pixel_values is present"
|
||||
|
||||
images = mm_data["images"]
|
||||
parsed_images = (self._get_data_parser().parse_mm_data({
|
||||
@@ -546,8 +608,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
get_best_fit(
|
||||
(image.size[1], image.size[0]),
|
||||
torch.tensor(possible_resolutions),
|
||||
resize_to_max_canvas=image_processor.resize_to_max_canvas)
|
||||
for image in parsed_images
|
||||
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
||||
) for image in parsed_images
|
||||
]
|
||||
# TODO tile height/width do not necessarily need to match
|
||||
aspect_ratios = [(image_size[0] // tile_size,
|
||||
@@ -659,13 +721,17 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.use_data_parallel = (vllm_config.parallel_config.
|
||||
enable_multimodal_encoder_data_parallel)
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.vision_model = Llama4VisionModel(config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "vision_model"))
|
||||
self.vision_model = Llama4VisionModel(
|
||||
config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
self.multi_modal_projector = Llama4MultiModalProjector(
|
||||
self.config,
|
||||
None,
|
||||
@@ -709,7 +775,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
flat_data = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"].tolist()
|
||||
|
||||
vision_embeddings_flat = self.vision_model(flat_data)
|
||||
# shard image input
|
||||
if self.use_data_parallel:
|
||||
vision_embeddings_flat = run_dp_sharded_vision_model(
|
||||
flat_data, self.vision_model)
|
||||
else:
|
||||
vision_embeddings_flat = self.vision_model(flat_data)
|
||||
|
||||
vision_embeddings_flat = self.multi_modal_projector(
|
||||
vision_embeddings_flat)
|
||||
|
||||
@@ -796,6 +868,30 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return get_prefix_weights(), get_other_weights()
|
||||
|
||||
def _consolidate_qkv_weights(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
qkv_idx_mappings = {
|
||||
".self_attn.q_proj": 0,
|
||||
".self_attn.k_proj": 1,
|
||||
".self_attn.v_proj": 2,
|
||||
}
|
||||
qkv_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
for weight_name, idx in qkv_idx_mappings.items():
|
||||
if weight_name not in name:
|
||||
continue
|
||||
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
|
||||
if new_name not in qkv_weights:
|
||||
qkv_weights[new_name] = [None] * 3
|
||||
qkv_weights[new_name][idx] = loaded_weight
|
||||
break
|
||||
else:
|
||||
yield name, loaded_weight
|
||||
for key, weight in qkv_weights.items():
|
||||
qkv_weight = torch.cat(weight, dim=0)
|
||||
yield key, qkv_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
|
||||
@@ -818,9 +914,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
assert loaded_language_model_params is not None
|
||||
updated_params.update(loaded_language_model_params)
|
||||
|
||||
if self.use_data_parallel:
|
||||
other_weights = self._consolidate_qkv_weights(other_weights)
|
||||
|
||||
for name, loaded_weight in other_weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
if weight_name not in name or self.use_data_parallel:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
|
||||
Reference in New Issue
Block a user