[VLM] Add TP support for Phi-4-MM (#14453)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2025-03-08 21:57:14 +08:00
Committed by: GitHub
Parent: cb8bdfade2
Commit: 03fe18ae0f
4 changed files with 50 additions and 295 deletions


@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
 from transformers.utils import logging
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model

 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None

         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
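
Note: the hf_to_vllm_mapper added above replaces the manual string rewriting that the old load_weights (removed further down in this commit) applied to every checkpoint name. The standalone Python sketch below is not part of the diff; it only approximates the intended WeightsMapper behavior (substring rewrite first, then the first matching prefix), and the sample checkpoint key is hypothetical.

# Standalone illustration of the renaming rules declared in hf_to_vllm_mapper.
# This mimics, rather than calls, vLLM's WeightsMapper.
ORIG_TO_NEW_SUBSTR = {"base_layer.": ""}
ORIG_TO_NEW_PREFIX = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
    "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
    "embed_tokens_extend.audio_projection.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
}

def remap(name: str) -> str:
    # Drop the LoRA wrapper segment, then rewrite the module prefix.
    for old, new in ORIG_TO_NEW_SUBSTR.items():
        name = name.replace(old, new)
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# Hypothetical key from the HF checkpoint -> name expected by the vLLM module tree.
print(remap("model.embed_tokens_extend.image_embed.img_projection.0.base_layer.weight"))
# vision_encoder.img_projection.0.weight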
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         self.lora_config = lora_config

         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
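
With the tensor-parallel assertion removed above, Phi-4-MM can be sharded across GPUs like other vLLM models. A minimal usage sketch follows, assuming the Hugging Face model id microsoft/Phi-4-multimodal-instruct and two visible GPUs; the model id and settings are illustrative and not taken from this commit.

from vllm import LLM, SamplingParams

# Illustrative only: run Phi-4-MM with tensor parallelism across 2 GPUs.
llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",
    trust_remote_code=True,
    tensor_parallel_size=2,
)
outputs = llm.generate(["Describe tensor parallelism in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)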
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         )
         return merged_embeds

-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed."\
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)

     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
@@ -1804,4 +1779,4 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
-        )
+        )
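
The new load_weights added earlier in this commit filters out checkpoint entries whose names contain "lora" before delegating to AutoWeightsLoader, presumably because Phi-4-MM's bundled vision/speech LoRA adapters are handled separately rather than merged into the base weights. Below is a standalone sketch of just that filtering step; the entry names are invented for illustration.

# Invented checkpoint entries; tensors replaced by placeholder strings.
checkpoint = [
    ("model.layers.0.self_attn.qkv_proj.base_layer.weight", "<tensor>"),
    ("model.layers.0.self_attn.qkv_proj.lora_A.speech.weight", "<tensor>"),
    ("model.embed_tokens_extend.image_embed.img_projection.0.bias", "<tensor>"),
]

# Same generator expression as in the diff: drop every LoRA adapter entry.
base_only = ((name, data) for name, data in checkpoint if "lora" not in name)
for name, _ in base_only:
    print(name)  # only the two non-LoRA names are printed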