[VLM] Migrate remaining DP-supported ViT models to use disable_tp (#24363)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
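The remaining DP-supported ViT models still picked their linear-layer class by hand, instantiating ReplicatedLinear when the vision tower runs data-parallel and ColumnParallelLinear/RowParallelLinear otherwise. This commit migrates them to construct the parallel classes unconditionally and pass disable_tp=use_data_parallel, so the DP-vs-TP decision lives inside the layer and call sites stay identical in both modes. A minimal self-contained sketch of the pattern (toy classes for illustration only, not vLLM's actual implementations):

import torch.nn as nn

# Toy stand-ins for illustration only -- NOT vLLM's ReplicatedLinear /
# ColumnParallelLinear. They only show the construction pattern this
# commit removes and the one it introduces.
class ToyReplicatedLinear(nn.Linear):
    """Full (unsharded) weight on every rank."""

class ToyColumnParallelLinear(nn.Linear):
    """Column-sharded weight; disable_tp=True falls back to a full weight."""
    def __init__(self, in_features, out_features, tp_size=1,
                 disable_tp=False):
        shards = 1 if disable_tp else tp_size
        super().__init__(in_features, out_features // shards)

use_data_parallel = True

# Before this commit: call sites chose the class per parallelism mode.
cls = ToyReplicatedLinear if use_data_parallel else ToyColumnParallelLinear
fc_old = cls(64, 128)

# After: one class everywhere; the flag opts the layer out of TP sharding.
fc_new = ToyColumnParallelLinear(64, 128, tp_size=4,
                                 disable_tp=use_data_parallel)
assert fc_old.weight.shape == fc_new.weight.shape == (128, 64)

The diff below applies this pattern to the Llama4 vision MLP and patch-embedding layers.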
@@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             input_size=input_size,
             output_size=intermediate_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             input_size=intermediate_size,
             output_size=output_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )
         self.activation_fn = nn.GELU()
         self.output_activation = output_activation
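For context, the module touched by this hunk is a standard two-layer MLP; a plain-PyTorch equivalent of its data flow (ignoring quantization, TP sharding, and the optional output activation, which this commit does not change) looks like:

import torch
import torch.nn as nn

# Plain-PyTorch equivalent of the Llama4VisionMLP data flow in the hunk
# above, ignoring quantization, TP sharding, and the output activation.
class MinimalVisionMLP(nn.Module):
    def __init__(self, input_size: int, intermediate_size: int,
                 output_size: int, bias: bool = True):
        super().__init__()
        self.fc1 = nn.Linear(input_size, intermediate_size, bias=bias)
        self.fc2 = nn.Linear(intermediate_size, output_size, bias=bias)
        self.activation_fn = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fc1 -> GELU -> fc2, matching the hunk above.
        return self.fc2(self.activation_fn(self.fc1(x)))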
@@ -419,20 +418,15 @@ class Llama4UnfoldConvolution(nn.Module):
             kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
                                       stride=config.patch_size)
-        params = {
-            "input_size":
-            config.num_channels * kernel_size[0] * kernel_size[1],
-            "output_size": config.hidden_size,
-            "bias": False,
-            "quant_config": quant_config,
-            "prefix": f"{prefix}.linear",
-        }
-        if use_data_parallel:
-            cls = ReplicatedLinear
-        else:
-            cls = ColumnParallelLinear
-            params["gather_output"] = True
-        self.linear = cls(**params)
+        self.linear = ColumnParallelLinear(
+            input_size=config.num_channels * kernel_size[0] * kernel_size[1],
+            output_size=config.hidden_size,
+            bias=False,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear",
+            disable_tp=use_data_parallel,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
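For reference, Llama4UnfoldConvolution's unfold-plus-linear pair is an unfold-based patch embedding: nn.Unfold extracts non-overlapping patches, and the linear layer projects each flattened patch to hidden_size. A standalone sketch with illustrative sizes (the real values come from config):

import torch

# Unfold-based patch embedding, as in Llama4UnfoldConvolution: nn.Unfold
# extracts non-overlapping patches; a linear layer projects each flattened
# patch to hidden_size. Sizes below are illustrative, not the real config.
patch, channels, hidden = 14, 3, 1408
unfold = torch.nn.Unfold(kernel_size=(patch, patch), stride=patch)
linear = torch.nn.Linear(channels * patch * patch, hidden, bias=False)

images = torch.randn(2, channels, 336, 336)    # (batch, C, H, W)
patches = unfold(images)                       # (2, C*patch*patch, 576)
embeds = linear(patches.permute(0, 2, 1))      # (2, 576, hidden)
print(embeds.shape)                            # torch.Size([2, 576, 1408])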