[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)
Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com> Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -56,6 +56,7 @@ from transformers.activations import GELUActivation
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
@@ -76,6 +77,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
|
||||
@@ -93,8 +95,10 @@ class MaxImageTokenMeta:
|
||||
|
||||
class KimiVLMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config: KimiVLConfig):
|
||||
def __init__(self, config: KimiVLConfig, \
|
||||
use_data_parallel: bool = False, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.use_data_parallel = use_data_parallel
|
||||
|
||||
self.hidden_size = (config.vision_config.hidden_size *
|
||||
config.vision_config.merge_kernel_size[0] *
|
||||
@@ -102,20 +106,24 @@ class KimiVLMultiModalProjector(nn.Module):
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
|
||||
eps=1e-5)
|
||||
self.linear_1 = nn.Linear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True)
|
||||
self.linear_1 = ReplicatedLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "linear_1"))
|
||||
self.linear_2 = ReplicatedLinear(self.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "linear_2"))
|
||||
self.act = GELUActivation()
|
||||
self.linear_2 = nn.Linear(self.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True)
|
||||
|
||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(image_features).view(
|
||||
-1, self.hidden_size)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states, _ = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states, _ = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -273,6 +281,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
|
||||
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
@@ -292,10 +302,17 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
assert isinstance(config.vision_config, MoonViTConfig)
|
||||
self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data"
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
self.vision_tower = MoonVitPretrainedModel(config.vision_config,
|
||||
self.use_data_parallel,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "vision_tower"))
|
||||
|
||||
self.vision_tower = MoonVitPretrainedModel(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
|
||||
self.multi_modal_projector = KimiVLMultiModalProjector(
|
||||
config=config,
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"))
|
||||
|
||||
self.quant_config = quant_config
|
||||
sub_vllm_config = copy.deepcopy(vllm_config)
|
||||
@@ -376,13 +393,19 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
image_grid_hws = inputs["image_grid_hws"]
|
||||
return self.vision_tower(pixel_values, image_grid_hws)
|
||||
if self.use_data_parallel:
|
||||
return run_dp_sharded_mrope_vision_model(self.vision_tower,
|
||||
pixel_values,
|
||||
image_grid_hws.tolist(),
|
||||
rope_type="rope_2d")
|
||||
else:
|
||||
return self.vision_tower(pixel_values, image_grid_hws)
|
||||
|
||||
def _process_image_input(self,
|
||||
image_input: KimiVLImageInputs) -> torch.Tensor:
|
||||
assert image_input["type"] == "pixel_values"
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
assert isinstance(image_features, list)
|
||||
assert isinstance(image_features, (list, tuple))
|
||||
lengths = [x.shape[0] for x in image_features]
|
||||
return self.multi_modal_projector(
|
||||
torch.cat(image_features)).split(lengths)
|
||||
@@ -496,6 +519,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
expert_params_mapping = []
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for args in weights:
|
||||
name, loaded_weight = args[:2]
|
||||
kwargs = args[2] if len(args) > 2 else {}
|
||||
|
||||
Reference in New Issue
Block a user