[Chore] Remove use_data_parallel kwargs from ViT implementation (#33310)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
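This change removes the `use_data_parallel: bool = False` keyword that was threaded through every ViT-related constructor. Each module now reads the setting itself via `is_vit_use_data_parallel()`, imported from `.vision` as the hunks below show, so callers no longer forward the flag. A minimal sketch of the before/after pattern, assuming a simplified module (`ToyVisionMLP*` is illustrative, not code from this commit):

    import torch.nn as nn
    from vllm.distributed import get_tensor_model_parallel_world_size
    from .vision import is_vit_use_data_parallel

    # Before: the flag had to be threaded through every constructor in the stack.
    class ToyVisionMLPBefore(nn.Module):
        def __init__(self, hidden_size: int, use_data_parallel: bool = False):
            super().__init__()
            # DP mode disables tensor-parallel sharding inside the ViT.
            self.tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()

    # After: each module asks the process-wide helper directly; the kwarg disappears.
    class ToyVisionMLPAfter(nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            use_data_parallel = is_vit_use_data_parallel()
            self.tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
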
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 
 
 class Idefics2VisionEmbeddings(nn.Module):
@@ -126,9 +126,9 @@ class Idefics2VisionAttention(nn.Module):
         config: Idefics2VisionConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
+        use_data_parallel = is_vit_use_data_parallel()
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -187,11 +187,12 @@ class Idefics2VisionMLP(nn.Module):
         config: Idefics2VisionConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
+
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
@@ -222,7 +223,6 @@ class Idefics2EncoderLayer(nn.Module):
         config: Idefics2Config,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -230,14 +230,12 @@ class Idefics2EncoderLayer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
-            use_data_parallel=use_data_parallel,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Idefics2VisionMLP(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
@@ -279,7 +277,6 @@ class Idefics2Encoder(nn.Module):
         *,
         num_hidden_layers_override: int | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -296,7 +293,6 @@ class Idefics2Encoder(nn.Module):
                     config,
                     quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
-                    use_data_parallel=use_data_parallel,
                 )
                 for layer_idx in range(num_hidden_layers)
             ]
@@ -331,20 +327,18 @@ class Idefics2VisionTransformer(nn.Module):
         num_hidden_layers_override: int | None = None,
         require_post_norm: bool = True,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         embed_dim = config.hidden_size
         self.config = config
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = is_vit_use_data_parallel()
         self.embeddings = Idefics2VisionEmbeddings(config)
         self.encoder = Idefics2Encoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            use_data_parallel=use_data_parallel,
         )
 
         num_hidden_layers = config.num_hidden_layers

@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 
 NORM2FN = {
     "rms_norm": RMSNorm,
@@ -148,7 +148,6 @@ class InternParallelAttention(nn.Module):
         *,
         num_dummy_heads: int = 0,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -163,9 +162,14 @@ class InternParallelAttention(nn.Module):
                 f" {self.num_heads})."
             )
 
-        self.tp_size = (
-            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        use_data_parallel = is_vit_use_data_parallel()
+        # if the number of heads is not divisible by tp_size,
+        # we also disable Attention's TP
+        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        use_data_parallel = (
+            use_data_parallel or (self.num_heads + num_dummy_heads) % tp_size != 0
         )
+        self.tp_size = 1 if use_data_parallel else tp_size
         self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
 
         # Additional dummy heads are used to enable TP for common GPU counts.
@@ -242,12 +246,12 @@ class InternMLP(nn.Module):
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
@@ -281,11 +285,9 @@ class InternVisionEncoderLayer(nn.Module):
         *,
         num_dummy_heads: int = 0,
         prefix: str = "",
-        use_data_parallel: bool = False,
         attn_cls: type[InternParallelAttention] = InternParallelAttention,
     ) -> None:
         super().__init__()
 
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
@@ -296,14 +298,12 @@ class InternVisionEncoderLayer(nn.Module):
             quant_config,
             num_dummy_heads=num_dummy_heads,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel,
         )
 
         self.mlp = InternMLP(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
         )
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -318,23 +318,12 @@ class InternVisionEncoderLayer(nn.Module):
         *,
         num_dummy_heads: int,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
-        # fallback to sdpa attention if tp unavailable
-        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
-        num_heads = config.num_attention_heads
-
-        # if the number of heads is not divisible by tp_size,
-        # we also disable Attention's TP
-        use_data_parallel = (
-            use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0
-        )
         return self.attn_cls(
             config,
             quant_config=quant_config,
             num_dummy_heads=num_dummy_heads,
             prefix=prefix,
-            use_data_parallel=use_data_parallel,
         )
 
     def forward(
@@ -357,7 +346,6 @@ class InternVisionEncoder(nn.Module):
         num_hidden_layers_override: int | None = None,
         num_dummy_heads: int = 0,
         prefix: str = "",
-        use_data_parallel: bool = False,
         layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
     ):
         super().__init__()
@@ -377,7 +365,6 @@ class InternVisionEncoder(nn.Module):
                     quant_config,
                     num_dummy_heads=num_dummy_heads,
                    prefix=f"{prefix}.layers.{layer_idx}",
-                    use_data_parallel=use_data_parallel,
                 )
                 for layer_idx in range(num_hidden_layers)
             ]
@@ -404,12 +391,11 @@ class InternVisionModel(nn.Module):
         num_hidden_layers_override: int | None = None,
         num_dummy_heads: int = 0,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         self.config = config
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = is_vit_use_data_parallel()
 
         self.embeddings = InternVisionEmbeddings(config)
         self.encoder = InternVisionEncoder(
@@ -418,7 +404,6 @@ class InternVisionModel(nn.Module):
             num_hidden_layers_override=num_hidden_layers_override,
             num_dummy_heads=num_dummy_heads,
             prefix=f"{prefix}.encoder",
-            use_data_parallel=use_data_parallel,
         )
 
     def get_input_embeddings(self):

@@ -1153,7 +1153,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
                 quant_config=quant_config,
                 num_hidden_layers_override=num_hidden_layers,
                 prefix=prefix,
-                use_data_parallel=self.use_data_parallel,
             )
         else:
             return InternVisionPatchModel(config.vision_config)

@@ -81,7 +81,7 @@ from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
-from .vision import run_dp_sharded_mrope_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
 
 
 # For dummy input only
@@ -93,10 +93,12 @@ class MaxImageTokenMeta:
 
 class KimiVLMultiModalProjector(nn.Module):
     def __init__(
-        self, config: KimiVLConfig, use_data_parallel: bool = False, prefix: str = ""
+        self,
+        config: KimiVLConfig,
+        prefix: str = "",
     ):
         super().__init__()
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = is_vit_use_data_parallel()
 
         self.hidden_size = (
             config.vision_config.hidden_size
@@ -321,7 +323,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         )
         self.multi_modal_projector = KimiVLMultiModalProjector(
             config=config,
-            use_data_parallel=self.use_data_parallel,
             prefix=maybe_prefix(prefix, "multi_modal_projector"),
         )
 
@@ -57,6 +57,7 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
+from .vision import is_vit_use_data_parallel
 
 
 class Lfm2VLImagePixelInputs(TensorSchema):
@@ -426,10 +427,12 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
 
 class Lfm2VLMultiModalProjector(nn.Module):
     def __init__(
-        self, config: Lfm2VlConfig, use_data_parallel: bool = False, prefix: str = ""
+        self,
+        config: Lfm2VlConfig,
+        prefix: str = "",
     ):
         super().__init__()
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = is_vit_use_data_parallel()
 
         in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
         self.factor = config.downsample_factor
@@ -607,7 +610,6 @@ class Lfm2VLForConditionalGeneration(
 
         self.multi_modal_projector = Lfm2VLMultiModalProjector(
             config=config,
-            use_data_parallel=self.use_data_parallel,
             prefix=maybe_prefix(prefix, "multi_modal_projector"),
         )
 
@@ -1335,7 +1335,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
             config.vision_config,
             quant_config=quant_config,
             prefix=prefix,
-            use_data_parallel=self.use_data_parallel,
         )
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
@@ -1428,7 +1427,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
             config.vision_config,
             quant_config=quant_config,
             prefix=prefix,
-            use_data_parallel=self.use_data_parallel,
         )
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
@@ -1526,7 +1524,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
             config.vision_config,
             quant_config=quant_config,
             prefix=prefix,
-            use_data_parallel=self.use_data_parallel,
         )
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
@@ -1624,7 +1621,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
             config.vision_config,
             quant_config=quant_config,
             prefix=prefix,
-            use_data_parallel=self.use_data_parallel,
         )
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]

@@ -79,7 +79,7 @@ from .interfaces import (
 )
 from .llama4 import Llama4ForCausalLM
 from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 
 
 class Llama4ImagePatchInputs(TensorSchema):
@@ -124,9 +124,9 @@ class Llama4VisionMLP(nn.Module):
         output_activation: bool,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             input_size=input_size,
             output_size=intermediate_size,
@@ -208,7 +208,6 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
         config,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
@@ -224,7 +223,6 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
             output_activation=True,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
         )
 
     def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
@@ -238,10 +236,10 @@ class Llama4VisionAttention(nn.Module):
         config: Llama4VisionConfig,
         quant_config: QuantizationConfig | None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
+        use_data_parallel = is_vit_use_data_parallel()
         self.tp_size = (
             1 if use_data_parallel else get_tensor_model_parallel_world_size()
         )
@@ -336,7 +334,6 @@ class Llama4VisionEncoderLayer(nn.Module):
         config: Llama4VisionConfig,
         quant_config: QuantizationConfig | None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -347,7 +344,6 @@ class Llama4VisionEncoderLayer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
-            use_data_parallel=use_data_parallel,
         )
         self.mlp = Llama4VisionMLP(
             input_size=config.hidden_size,
@@ -357,7 +353,6 @@ class Llama4VisionEncoderLayer(nn.Module):
             output_activation=False,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
         )
 
         self.input_layernorm = nn.LayerNorm(config.hidden_size)
@@ -389,7 +384,6 @@ class Llama4VisionEncoder(nn.Module):
         config: Llama4VisionConfig,
         quant_config: QuantizationConfig | None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -399,7 +393,6 @@ class Llama4VisionEncoder(nn.Module):
                     config,
                     quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
-                    use_data_parallel=use_data_parallel,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
@@ -432,13 +425,13 @@ class Llama4UnfoldConvolution(nn.Module):
         config: Llama4VisionConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         kernel_size = config.patch_size
         if isinstance(kernel_size, int):
             kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
+        use_data_parallel = is_vit_use_data_parallel()
         self.linear = ColumnParallelLinear(
             input_size=config.num_channels * kernel_size[0] * kernel_size[1],
             output_size=config.hidden_size,
@@ -465,7 +458,6 @@ class Llama4VisionModel(nn.Module):
         config: Llama4VisionConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -481,7 +473,6 @@ class Llama4VisionModel(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.patch_embedding",
-            use_data_parallel=use_data_parallel,
         )
 
         self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
@@ -498,14 +489,12 @@ class Llama4VisionModel(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.model",
-            use_data_parallel=use_data_parallel,
         )
 
         self.vision_adapter = Llama4VisionPixelShuffleMLP(
             config,
             quant_config,
             prefix=f"{prefix}.vision_adapter",
-            use_data_parallel=use_data_parallel,
         )
 
     def forward(
@@ -780,7 +769,6 @@ class Llama4ForConditionalGeneration(
             config=config.vision_config,
             quant_config=None,
             prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
         )
 
         self.multi_modal_projector = Llama4MultiModalProjector(

@@ -54,7 +54,7 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 
 
 class Step3VLImagePixelInputs(TensorSchema):
@@ -724,7 +724,6 @@ class Step3VisionAttention(nn.Module):
         config,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
@@ -734,6 +733,7 @@ class Step3VisionAttention(nn.Module):
 
         self.scale = self.head_dim**-0.5
 
+        use_data_parallel = is_vit_use_data_parallel()
         tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -786,11 +786,11 @@ class Step3VisionMLP(nn.Module):
         config,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
@@ -821,23 +821,19 @@ class Step3VisionEncoderLayer(nn.Module):
         config: Step3VisionEncoderConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.use_data_parallel = use_data_parallel
         self.embed_dim = config.hidden_size
         self.self_attn = Step3VisionAttention(
             config,
             quant_config,
             prefix=f"{prefix}.self_attn",
-            use_data_parallel=self.use_data_parallel,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Step3VisionMLP(
             config,
             quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=self.use_data_parallel,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
@@ -856,18 +852,15 @@ class Step3VisionEncoder(nn.Module):
         config: Step3VisionEncoderConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
-        self.use_data_parallel = use_data_parallel
         self.layers = nn.ModuleList(
             [
                 Step3VisionEncoderLayer(
                     config,
                     quant_config,
                     prefix=f"{prefix}.layers.{i}",
-                    use_data_parallel=self.use_data_parallel,
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -889,18 +882,16 @@ class Step3VisionTransformer(nn.Module):
         config: Step3VisionEncoderConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = is_vit_use_data_parallel()
         self.image_size = config.image_size
         self.embeddings = Step3VisionEmbeddings(config)
         self.transformer = Step3VisionEncoder(
             config,
             quant_config,
             prefix=f"{prefix}.transformer",
-            use_data_parallel=self.use_data_parallel,
         )
 
     def forward(
@@ -952,7 +943,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             config.vision_config,
             None,
             prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
         )
         self.vit_downsampler = Conv2dLayer(
             config.vision_config.hidden_size,

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 
 from .step3_vl import Step3VLForConditionalGeneration
 from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
-from .vision import run_dp_sharded_vision_model
+from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
 
 _DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
 
@@ -151,9 +151,9 @@ class PerceptionEncoderMLP(nn.Module):
         act_layer: Callable[[], nn.Module],
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
+        use_data_parallel = is_vit_use_data_parallel()
         self.fc1 = ColumnParallelLinear(
             input_dim,
             hidden_dim,
@@ -189,7 +189,6 @@ class PerceptionEncoderVisionAttention(nn.Module):
         use_cls_token: bool = False,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -197,6 +196,7 @@ class PerceptionEncoderVisionAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim**-0.5
 
+        use_data_parallel = is_vit_use_data_parallel()
         tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_size == 0, (
             "embed_dim must be divisible by num_heads"
@@ -258,7 +258,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
         use_cls_token: bool = False,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.attn = PerceptionEncoderVisionAttention(
@@ -269,7 +268,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
             use_cls_token=use_cls_token,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel,
         )
         self.ls_1 = (
             PerceptionEncoderLayerScale(d_model, ls_init_value)
@@ -290,7 +288,6 @@ class PerceptionEncoderVisionBlock(nn.Module):
             act_layer,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
         )
 
     def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
@@ -314,7 +311,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
         use_cls_token: bool = False,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.width = width
@@ -333,7 +329,6 @@ class PerceptionEncoderVisionTransformer(nn.Module):
                     use_cls_token=use_cls_token,
                     quant_config=quant_config,
                    prefix=f"{prefix}.resblocks.{i}",
-                    use_data_parallel=use_data_parallel,
                 )
                 for i in range(layers)
             ]
@@ -353,7 +348,6 @@ class PerceptionEncoder(nn.Module):
         norm_layer: Callable = _DEFAULT_NORM_LAYER,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.patch_size = config.patch_size
@@ -394,7 +388,6 @@ class PerceptionEncoder(nn.Module):
             use_cls_token=self.use_cls_token,
             quant_config=quant_config,
             prefix=f"{prefix}.transformer",
-            use_data_parallel=use_data_parallel,
        )
 
         self.vit_downsampler1 = Conv2dLayer(
@@ -511,7 +504,6 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
             get_act_fn(config.vision_config.hidden_act),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
         )
         self.vit_large_projector = ColumnParallelLinear(
             config.vision_config.width * 4,

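The one place where this change is more than mechanical is InternParallelAttention, which now folds in the divisibility fallback that previously lived in the deleted `_init_attn` helper. The decision reduces to the following (a standalone sketch mirroring the hunk, not code from the commit; `resolve_vit_tp` is a hypothetical name):

    def resolve_vit_tp(num_heads: int, num_dummy_heads: int,
                       dp_requested: bool, tp_world_size: int) -> tuple[bool, int]:
        # Tensor parallelism is skipped when data parallelism is requested, or
        # when the padded head count cannot be split evenly across TP ranks.
        tp_size = 1 if dp_requested else tp_world_size
        use_dp = dp_requested or (num_heads + num_dummy_heads) % tp_size != 0
        return use_dp, 1 if use_dp else tp_size

    # Example: 25 heads, no dummy heads, TP world size 8 -> (True, 1);
    # the attention then runs unsharded on every rank instead of failing.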