[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Author: WeiQing Chen
Date: 2025-09-02 00:56:56 +08:00
Committed by: GitHub
Parent: cf91a89dd2
Commit: a0e0efd6bd
6 changed files with 156 additions and 61 deletions


@@ -56,6 +56,7 @@ from transformers.activations import GELUActivation
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
@@ -76,6 +77,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
@@ -93,8 +95,10 @@ class MaxImageTokenMeta:
 class KimiVLMultiModalProjector(nn.Module):
 
-    def __init__(self, config: KimiVLConfig):
+    def __init__(self, config: KimiVLConfig, \
+                 use_data_parallel: bool = False, prefix: str = ""):
         super().__init__()
+        self.use_data_parallel = use_data_parallel
         self.hidden_size = (config.vision_config.hidden_size *
                             config.vision_config.merge_kernel_size[0] *
@@ -102,20 +106,24 @@ class KimiVLMultiModalProjector(nn.Module):
         self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size,
                                            eps=1e-5)
-        self.linear_1 = nn.Linear(self.hidden_size,
-                                  self.hidden_size,
-                                  bias=True)
+        self.linear_1 = ReplicatedLinear(self.hidden_size,
+                                         self.hidden_size,
+                                         bias=True,
+                                         prefix=maybe_prefix(
+                                             prefix, "linear_1"))
+        self.linear_2 = ReplicatedLinear(self.hidden_size,
+                                         config.text_config.hidden_size,
+                                         bias=True,
+                                         prefix=maybe_prefix(
+                                             prefix, "linear_2"))
         self.act = GELUActivation()
-        self.linear_2 = nn.Linear(self.hidden_size,
-                                  config.text_config.hidden_size,
-                                  bias=True)
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states = self.pre_norm(image_features).view(
             -1, self.hidden_size)
-        hidden_states = self.linear_1(hidden_states)
+        hidden_states, _ = self.linear_1(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states
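
Reviewer note: the switch to tuple unpacking in `forward` is required because vLLM's parallel linear layers, `ReplicatedLinear` included, return an `(output, output_bias)` pair rather than a bare tensor like `nn.Linear`. A minimal stand-in (not the real vLLM layer) showing the calling convention:

```python
import torch
import torch.nn as nn


class TupleLinear(nn.Module):
    """Stand-in for vLLM's convention: forward returns (output, bias)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x: torch.Tensor):
        # Bias is already fused into the output here, so the second
        # element of the tuple is None.
        return self.inner(x), None


layer = TupleLinear(8, 16)
out, _ = layer(torch.randn(4, 8))  # unpack exactly as the projector does
assert out.shape == (4, 16)
```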
@@ -273,6 +281,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
 
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -292,10 +302,17 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         quant_config = vllm_config.quant_config
 
         assert isinstance(config.vision_config, MoonViTConfig)
+        self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data"
         self.hidden_size = config.text_config.hidden_size
-        self.vision_tower = MoonVitPretrainedModel(config.vision_config)
-        self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
+        self.vision_tower = MoonVitPretrainedModel(config.vision_config,
+                                                   self.use_data_parallel,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "vision_tower"))
+        self.multi_modal_projector = KimiVLMultiModalProjector(
+            config=config,
+            use_data_parallel=self.use_data_parallel,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
         self.quant_config = quant_config
         sub_vllm_config = copy.deepcopy(vllm_config)
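
Reviewer note: `mm_encoder_tp_mode` is read from vLLM's multimodal config, and the `supports_encoder_tp_data = True` flag added above appears to be what advertises that this model accepts the "data" mode. A hedged usage sketch (engine-argument plumbing assumed from the config field read in the constructor, not verified here):

```python
from vllm import LLM

# Sketch, not a verified launch recipe: with mm_encoder_tp_mode="data",
# use_data_parallel becomes True, so the ViT weights are replicated on
# every rank and images are sharded across ranks instead of the encoder
# being tensor-sharded.
llm = LLM(
    model="moonshotai/Kimi-VL-A3B-Thinking-2506",
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",
    trust_remote_code=True,
)
```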
@@ -376,13 +393,19 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         pixel_values = inputs["pixel_values"]
         image_grid_hws = inputs["image_grid_hws"]
-        return self.vision_tower(pixel_values, image_grid_hws)
+        if self.use_data_parallel:
+            return run_dp_sharded_mrope_vision_model(self.vision_tower,
+                                                     pixel_values,
+                                                     image_grid_hws.tolist(),
+                                                     rope_type="rope_2d")
+        else:
+            return self.vision_tower(pixel_values, image_grid_hws)
 
     def _process_image_input(self,
                              image_input: KimiVLImageInputs) -> torch.Tensor:
         assert image_input["type"] == "pixel_values"
         image_features = self._process_image_pixels(image_input)
-        assert isinstance(image_features, list)
+        assert isinstance(image_features, (list, tuple))
         lengths = [x.shape[0] for x in image_features]
         return self.multi_modal_projector(
             torch.cat(image_features)).split(lengths)
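
Reviewer note: `run_dp_sharded_mrope_vision_model` is handed the whole vision tower plus the flattened patches and per-image `(h, w)` grids, so the sharding granularity is whole images. A simplified sketch of the idea only (not vLLM's implementation; the helper name and round-robin assignment are illustrative):

```python
import torch


def shard_images_for_rank(pixel_values: torch.Tensor,
                          grid_hws: list[tuple[int, int]],
                          dp_rank: int, dp_size: int):
    """Select the slice of whole images this DP rank should encode."""
    lengths = [h * w for h, w in grid_hws]  # patches per image
    offsets = [0]
    for n in lengths:
        offsets.append(offsets[-1] + n)
    my_ids = list(range(dp_rank, len(lengths), dp_size))  # round-robin
    if my_ids:
        my_patches = torch.cat(
            [pixel_values[offsets[i]:offsets[i + 1]] for i in my_ids])
    else:  # this rank drew no images; encode an empty batch
        my_patches = pixel_values.new_empty((0, *pixel_values.shape[1:]))
    my_grids = [grid_hws[i] for i in my_ids]
    # Each rank runs vision_tower(my_patches, my_grids); the per-image
    # outputs are then all-gathered and reordered by image index.
    return my_ids, my_patches, my_grids
```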
@@ -496,6 +519,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         expert_params_mapping = []
         params_dict = dict(self.named_parameters())
         for args in weights:
             name, loaded_weight = args[:2]
             kwargs = args[2] if len(args) > 2 else {}
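
Reviewer note: the `args[:2]` plus optional-`kwargs` unpacking lets the weight stream mix plain `(name, tensor)` pairs with `(name, tensor, kwargs)` triples, so producers that attach extra loader metadata pass through unchanged. A small self-contained illustration (the sample names and kwargs are made up):

```python
import torch

weights = [
    ("multi_modal_projector.linear_1.weight", torch.zeros(16, 16)),
    ("language_model.experts.w13_weight", torch.zeros(16, 16),
     {"shard_id": "w1", "expert_id": 0}),  # hypothetical extra kwargs
]
for args in weights:
    name, loaded_weight = args[:2]
    kwargs = args[2] if len(args) > 2 else {}
    print(f"{name}: shape={tuple(loaded_weight.shape)} kwargs={kwargs}")
```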