[Chore] Remove use_data_parallel kwargs from ViT implementation (#33310)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-29 18:20:52 +08:00
committed by GitHub
parent 3a92c6f3b5
commit 5400014d55
9 changed files with 36 additions and 89 deletions

View File

@@ -54,7 +54,7 @@ from .utils import (
init_vllm_registered_model,
maybe_prefix,
)
from .vision import run_dp_sharded_vision_model
from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
class Step3VLImagePixelInputs(TensorSchema):
@@ -724,7 +724,6 @@ class Step3VisionAttention(nn.Module):
config,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@@ -734,6 +733,7 @@ class Step3VisionAttention(nn.Module):
self.scale = self.head_dim**-0.5
use_data_parallel = is_vit_use_data_parallel()
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
@@ -786,11 +786,11 @@ class Step3VisionMLP(nn.Module):
config,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
use_data_parallel = is_vit_use_data_parallel()
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
@@ -821,23 +821,19 @@ class Step3VisionEncoderLayer(nn.Module):
config: Step3VisionEncoderConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.use_data_parallel = use_data_parallel
self.embed_dim = config.hidden_size
self.self_attn = Step3VisionAttention(
config,
quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=self.use_data_parallel,
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Step3VisionMLP(
config,
quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=self.use_data_parallel,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -856,18 +852,15 @@ class Step3VisionEncoder(nn.Module):
config: Step3VisionEncoderConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.use_data_parallel = use_data_parallel
self.layers = nn.ModuleList(
[
Step3VisionEncoderLayer(
config,
quant_config,
prefix=f"{prefix}.layers.{i}",
use_data_parallel=self.use_data_parallel,
)
for i in range(config.num_hidden_layers)
]
@@ -889,18 +882,16 @@ class Step3VisionTransformer(nn.Module):
config: Step3VisionEncoderConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.use_data_parallel = use_data_parallel
self.use_data_parallel = is_vit_use_data_parallel()
self.image_size = config.image_size
self.embeddings = Step3VisionEmbeddings(config)
self.transformer = Step3VisionEncoder(
config,
quant_config,
prefix=f"{prefix}.transformer",
use_data_parallel=self.use_data_parallel,
)
def forward(
@@ -952,7 +943,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
self.vit_downsampler = Conv2dLayer(
config.vision_config.hidden_size,