[Chore] Remove use_data_parallel kwargs from ViT implementation (#33310)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
from .step3_vl import Step3VLForConditionalGeneration
|
||||
from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
|
||||
|
||||
_DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
|
||||
|
||||
@@ -151,9 +151,9 @@ class PerceptionEncoderMLP(nn.Module):
|
||||
act_layer: Callable[[], nn.Module],
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
@@ -189,7 +189,6 @@ class PerceptionEncoderVisionAttention(nn.Module):
|
||||
use_cls_token: bool = False,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@@ -197,6 +196,7 @@ class PerceptionEncoderVisionAttention(nn.Module):
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_size == 0, (
|
||||
"embed_dim must be divisible by num_heads"
|
||||
@@ -258,7 +258,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
|
||||
use_cls_token: bool = False,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = PerceptionEncoderVisionAttention(
|
||||
@@ -269,7 +268,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
|
||||
use_cls_token=use_cls_token,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
self.ls_1 = (
|
||||
PerceptionEncoderLayerScale(d_model, ls_init_value)
|
||||
@@ -290,7 +288,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
|
||||
act_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
|
||||
@@ -314,7 +311,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
|
||||
use_cls_token: bool = False,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
@@ -333,7 +329,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
|
||||
use_cls_token=use_cls_token,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.resblocks.{i}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
for i in range(layers)
|
||||
]
|
||||
@@ -353,7 +348,6 @@ class PerceptionEncoder(nn.Module):
|
||||
norm_layer: Callable = _DEFAULT_NORM_LAYER,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = config.patch_size
|
||||
@@ -394,7 +388,6 @@ class PerceptionEncoder(nn.Module):
|
||||
use_cls_token=self.use_cls_token,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.transformer",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.vit_downsampler1 = Conv2dLayer(
|
||||
@@ -511,7 +504,6 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
|
||||
get_act_fn(config.vision_config.hidden_act),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
self.vit_large_projector = ColumnParallelLinear(
|
||||
config.vision_config.width * 4,
|
||||
|
||||
Reference in New Issue
Block a user