[Chore] Remove use_data_parallel kwargs from ViT implementation (#33310)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2026-01-29 18:20:52 +08:00
Committed by: GitHub
Parent: 3a92c6f3b5
Commit: 5400014d55
9 changed files with 36 additions and 89 deletions
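The same refactor is applied in every touched file: instead of threading a use_data_parallel: bool = False keyword argument through each vision-module constructor, the module now queries the shared helper is_vit_use_data_parallel() (imported from .vision) when it is built, which keeps the constructor signatures in sync and removes the kwarg plumbing shown in the hunks below. A minimal sketch of the new pattern follows; VisionAttention is a hypothetical stand-in for the ViT attention blocks, while the helper call and the tp_size expression are taken from the diff itself.

# Illustrative sketch only; VisionAttention is a hypothetical stand-in, not code from this commit.
from torch import nn

from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.models.vision import is_vit_use_data_parallel


class VisionAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        # New pattern: read the flag from the shared helper instead of a
        # use_data_parallel kwarg passed down by every caller.
        use_data_parallel = is_vit_use_data_parallel()
        # Same head-splitting logic as in the diff: under data parallelism the
        # ViT is replicated per rank, so no tensor-parallel split (tp_size == 1).
        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        self.num_heads = num_heads // tp_size
        self.head_dim = embed_dim // num_heads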


@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from .step3_vl import Step3VLForConditionalGeneration
 from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 _DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
@@ -151,9 +151,9 @@ class PerceptionEncoderMLP(nn.Module):
         act_layer: Callable[[], nn.Module],
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             input_dim,
             hidden_dim,
@@ -189,7 +189,6 @@ class PerceptionEncoderVisionAttention(nn.Module):
         use_cls_token: bool = False,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -197,6 +196,7 @@
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim**-0.5
+        use_data_parallel = is_vit_use_data_parallel()
         tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_size == 0, (
             "embed_dim must be divisible by num_heads"
@@ -258,7 +258,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
         use_cls_token: bool = False,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.attn = PerceptionEncoderVisionAttention(
@@ -269,7 +268,6 @@
             use_cls_token=use_cls_token,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel,
         )
         self.ls_1 = (
             PerceptionEncoderLayerScale(d_model, ls_init_value)
@@ -290,7 +288,6 @@
             act_layer,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
         )
     def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
@@ -314,7 +311,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
         use_cls_token: bool = False,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.width = width
@@ -333,7 +329,6 @@
                     use_cls_token=use_cls_token,
                     quant_config=quant_config,
                     prefix=f"{prefix}.resblocks.{i}",
-                    use_data_parallel=use_data_parallel,
                 )
                 for i in range(layers)
             ]
@@ -353,7 +348,6 @@ class PerceptionEncoder(nn.Module):
         norm_layer: Callable = _DEFAULT_NORM_LAYER,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.patch_size = config.patch_size
@@ -394,7 +388,6 @@
             use_cls_token=self.use_cls_token,
             quant_config=quant_config,
             prefix=f"{prefix}.transformer",
-            use_data_parallel=use_data_parallel,
         )
         self.vit_downsampler1 = Conv2dLayer(
@@ -511,7 +504,6 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
             get_act_fn(config.vision_config.hidden_act),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
         )
         self.vit_large_projector = ColumnParallelLinear(
             config.vision_config.width * 4,