[VLM] Add TP support for Phi-4-MM (#14453)
Signed-off-by: Isotr0py <2037008807@qq.com>
@@ -15,7 +15,7 @@ from transformers import PretrainedConfig
 from transformers.utils import logging

 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@ from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model

 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None

         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
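The surviving lines keep the attribute-fallback idiom this encoder uses for configs that expose either n_embd (GPT-2-style) or hidden_size (most HF configs). A standalone sketch of the same idiom; the helper name and the SimpleNamespace stand-in configs are ours, for illustration only:

from types import SimpleNamespace

def get_hidden_size(config) -> int:
    # Prefer n_embd when the config defines it, else fall back to
    # hidden_size, mirroring the hasattr check kept in the hunk above.
    return config.n_embd if hasattr(config, "n_embd") else config.hidden_size

assert get_hidden_size(SimpleNamespace(n_embd=1024)) == 1024
assert get_hidden_size(SimpleNamespace(hidden_size=3072)) == 3072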
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
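The class-level hf_to_vllm_mapper added here declaratively encodes the checkpoint-to-module renames that the old load_weights (removed below) performed by hand. A dependency-free sketch of the semantics; remap_name is our own helper, while the prefix strings come straight from the diff:

def remap_name(name: str) -> str:
    # Longer, more specific prefixes are listed first so they win over
    # the generic "audio_embed." rule (dicts preserve insertion order).
    prefixes = {
        "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
        "embed_tokens_extend.audio_projection_for_vision.",
        "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
        "embed_tokens_extend.audio_projection.",
        "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
        "model.embed_tokens_extend.image_embed.": "vision_encoder.",
    }
    for old, new in prefixes.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    # Substring rule: drop LoRA's "base_layer." wrapper wherever it appears.
    return name.replace("base_layer.", "")

assert (remap_name("model.embed_tokens_extend.image_embed.base_layer.weight")
        == "vision_encoder.weight")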
@@ -1445,8 +1453,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         self.lora_config = lora_config

         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"

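With the tensor-parallel assertion deleted, the model can be sharded across GPUs. A minimal sketch of what this unlocks through vLLM's public API; the model id and world size are illustrative assumptions:

from vllm import LLM

# tensor_parallel_size=2 shards the model across two GPUs; previously
# this would have tripped the assert removed above.
llm = LLM(model="microsoft/Phi-4-multimodal-instruct",
          tensor_parallel_size=2,
          trust_remote_code=True)

Note that the pipeline-parallel assertion stays in place, so the PP world size must still be 1.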
@@ -1686,44 +1692,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         )
         return merged_embeds

-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed."\
-                "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
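The deleted loader materialized the whole checkpoint into a dict, renamed keys by hand, then relied on load_state_dict(strict=False) plus debug logging to surface mismatches. A self-contained illustration of that strict=False behavior on a toy module (ours):

import torch
import torch.nn as nn

m = nn.Linear(2, 2)
# "bias" is absent and "extra" is unknown; strict=False reports both
# instead of raising, which is what the removed code then logged.
result = m.load_state_dict({"weight": torch.zeros(2, 2),
                            "extra": torch.zeros(2)}, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra']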
@@ -1796,6 +1764,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
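The replacement loader is a thin wrapper: it lazily filters out LoRA tensors with a generator expression and delegates the renaming to AutoWeightsLoader plus the class-level mapper. A toy sketch of why the filter is a generator rather than a dict; drop_lora is our own name:

from typing import Iterable, Iterator, Tuple

def drop_lora(weights: Iterable[Tuple[str, int]]) -> Iterator[Tuple[str, int]]:
    # The stream is consumed lazily, so skipped LoRA tensors never force
    # the whole checkpoint to be held in memory at once.
    return ((name, data) for name, data in weights if "lora" not in name)

stream = iter([("layer.weight", 1), ("layer.lora_A.weight", 2)])
assert list(drop_lora(stream)) == [("layer.weight", 1)]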
@@ -1804,4 +1779,4 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
         )
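For context, the prefix groups returned by get_mm_mapping let callers classify a parameter as belonging to the language model, the connector, or a tower (used e.g. for LoRA targeting). A toy bucketing function over those same prefixes; the helper is ours, not vLLM's API:

PREFIXES = {
    "language_model": ("model.",),
    "connector": ("audio_projection_for_vision", "audio_projection"),
    "tower_model": ("vision_encoder", "embed_tokens_extend"),
}

def bucket(name: str) -> str:
    # str.startswith accepts a tuple, so each group is a single check.
    for component, prefixes in PREFIXES.items():
        if name.startswith(prefixes):
            return component
    return "unmatched"

assert bucket("model.layers.0.self_attn.qkv_proj.weight") == "language_model"
assert bucket("vision_encoder.blocks.0.mlp.fc1.weight") == "tower_model"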