[Chore] Remove use_data_parallel kwargs from ViT implementation (#33310)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
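Resolve the ViT data-parallel setting via is_vit_use_data_parallel() from .vision at the point of use, instead of threading a use_data_parallel kwarg through every Step3 vision-module constructor (attention, MLP, encoder layer, encoder, and transformer).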
@@ -54,7 +54,7 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 
 
 class Step3VLImagePixelInputs(TensorSchema):
@@ -724,7 +724,6 @@ class Step3VisionAttention(nn.Module):
         config,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -734,6 +733,7 @@ class Step3VisionAttention(nn.Module):
 
         self.scale = self.head_dim**-0.5
 
+        use_data_parallel = is_vit_use_data_parallel()
         tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -786,11 +786,11 @@ class Step3VisionMLP(nn.Module):
         config,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
@@ -821,23 +821,19 @@ class Step3VisionEncoderLayer(nn.Module):
         config: Step3VisionEncoderConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.use_data_parallel = use_data_parallel
         self.embed_dim = config.hidden_size
         self.self_attn = Step3VisionAttention(
             config,
             quant_config,
             prefix=f"{prefix}.self_attn",
-            use_data_parallel=self.use_data_parallel,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Step3VisionMLP(
             config,
             quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=self.use_data_parallel,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
@@ -856,18 +852,15 @@ class Step3VisionEncoder(nn.Module):
         config: Step3VisionEncoderConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
-        self.use_data_parallel = use_data_parallel
         self.layers = nn.ModuleList(
             [
                 Step3VisionEncoderLayer(
                     config,
                     quant_config,
                     prefix=f"{prefix}.layers.{i}",
-                    use_data_parallel=self.use_data_parallel,
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -889,18 +882,16 @@ class Step3VisionTransformer(nn.Module):
         config: Step3VisionEncoderConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = is_vit_use_data_parallel()
         self.image_size = config.image_size
         self.embeddings = Step3VisionEmbeddings(config)
         self.transformer = Step3VisionEncoder(
             config,
             quant_config,
             prefix=f"{prefix}.transformer",
-            use_data_parallel=self.use_data_parallel,
         )
 
     def forward(
@@ -952,7 +943,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             config.vision_config,
             None,
             prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
         )
         self.vit_downsampler = Conv2dLayer(
             config.vision_config.hidden_size,
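For readers outside the vLLM tree: the refactor swaps per-constructor flag plumbing for a single query helper. Below is a minimal sketch of that pattern, assuming a module-level flag with a hypothetical setter; the real is_vit_use_data_parallel() in vLLM's .vision module may derive the value from the parallel/model config instead. The VisionAttention body mirrors the diff's "tp_size = 1 if use_data_parallel else ..." logic.

    # Sketch of the kwarg-plumbing -> central-query refactor;
    # not the actual vLLM implementation.
    _vit_use_data_parallel: bool = False

    def set_vit_use_data_parallel(value: bool) -> None:
        # Hypothetical setter, called once during model construction.
        global _vit_use_data_parallel
        _vit_use_data_parallel = value

    def is_vit_use_data_parallel() -> bool:
        # Submodules query this directly, so the flag no longer has
        # to be passed down constructor by constructor.
        return _vit_use_data_parallel

    class VisionAttention:
        def __init__(self, total_num_heads: int, tp_size: int) -> None:
            # When the ViT runs data-parallel, tensor parallelism is
            # effectively disabled for it (tp_size -> 1), as in the diff.
            effective_tp = 1 if is_vit_use_data_parallel() else tp_size
            assert total_num_heads % effective_tp == 0
            self.num_heads = total_num_heads // effective_tp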