[Model][VLM] Add Qwen2.5-Omni model support (thinker only) (#15130)
Signed-off-by: fyabc <suyang.fy@alibaba-inc.com> Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Roger Wang <ywang@roblox.com> Co-authored-by: Xiong Wang <wangxiongts@163.com>
This commit is contained in:
@@ -38,13 +38,14 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
@@ -195,6 +196,23 @@ class Qwen2_5_VisionMLP(nn.Module):
|
||||
return x_down
|
||||
|
||||
|
||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
||||
"""All-gather the input tensor interleavely across model parallel group."""
|
||||
import torch.distributed as dist
|
||||
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
|
||||
dist.all_gather(gathered_tensors, local_tensor)
|
||||
|
||||
gathered_tensors_split = [
|
||||
torch.split(tensor, hidden_size // tp_size, -1)
|
||||
for tensor in gathered_tensors
|
||||
]
|
||||
ordered_tensors = [
|
||||
tensor for pair in zip(*gathered_tensors_split) for tensor in pair
|
||||
]
|
||||
result_tensor = torch.cat(ordered_tensors, dim=-1)
|
||||
return result_tensor
|
||||
|
||||
|
||||
class Qwen2_5_VisionAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -214,10 +232,14 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_heads, self.tp_size)
|
||||
|
||||
self.qkv = ColumnParallelLinear(input_size=embed_dim,
|
||||
output_size=3 * projection_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv")
|
||||
self.qkv = QKVParallelLinear(
|
||||
hidden_size=embed_dim,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
total_num_heads=num_heads,
|
||||
total_num_kv_heads=num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv")
|
||||
self.proj = RowParallelLinear(input_size=projection_size,
|
||||
output_size=embed_dim,
|
||||
quant_config=quant_config,
|
||||
@@ -236,7 +258,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
if self.tp_size > 1:
|
||||
qkv = tensor_model_parallel_all_gather(qkv)
|
||||
qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
|
||||
self.tp_size)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
@@ -694,9 +717,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("attn.qkv.", "attn.q.", "q"),
|
||||
("attn.qkv.", "attn.k.", "k"),
|
||||
("attn.qkv.", "attn.v.", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
@@ -952,20 +975,20 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return video_embeds.split(sizes.tolist())
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
mm_input_by_modality = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values",
|
||||
"image_embeds") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if input_key in ("pixel_values_videos",
|
||||
"video_embeds") and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
return modalities
|
||||
if input_key in ("pixel_values", "image_embeds"
|
||||
) and "image" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key in ("pixel_values_videos", "video_embeds"
|
||||
) and "video" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"video"] = self._parse_and_validate_video_input(**kwargs)
|
||||
return mm_input_by_modality
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
@@ -973,8 +996,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||
**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
return None
|
||||
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
@@ -983,14 +1007,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
for modality in mm_input_by_modality:
|
||||
multimodal_input = mm_input_by_modality[modality]
|
||||
if modality == "image":
|
||||
vision_embeddings = self._process_image_input(multimodal_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_video_input(video_input)
|
||||
if modality == "video":
|
||||
video_embeddings = self._process_video_input(multimodal_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
return multimodal_embeddings
|
||||
|
||||
|
||||
Reference in New Issue
Block a user