[Model] Re-implement Qwen3Omni Audio Encoder (#32167)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -31,29 +31,34 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from packaging.version import Version
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
|
||||
Qwen3OmniMoeAudioEncoderConfig,
|
||||
Qwen3OmniMoeConfig,
|
||||
Qwen3OmniMoeThinkerConfig,
|
||||
)
|
||||
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
|
||||
Qwen3OmniMoeAudioEncoder,
|
||||
)
|
||||
from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
|
||||
Qwen3OmniMoeProcessor,
|
||||
)
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
# isort: off
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
# isort: on
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import MultiModalConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.attention.mm_encoder_attention import (
|
||||
MMEncoderAttention,
|
||||
)
|
||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -104,11 +109,6 @@ from .vision import (
|
||||
get_vit_attn_backend,
|
||||
)
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
flash_attn = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -121,6 +121,415 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
||||
return output_lengths
|
||||
|
||||
|
||||
# ============= Audio Encoder Components =============
|
||||
|
||||
|
||||
class SinusoidsPositionEmbedding(nn.Module):
|
||||
"""Sinusoidal position embedding for audio encoder."""
|
||||
|
||||
def __init__(self, length: int, channels: int, max_timescale: int = 10000):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.channels = channels
|
||||
self.max_timescale = max_timescale
|
||||
|
||||
if channels % 2 != 0:
|
||||
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
|
||||
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(
|
||||
-log_timescale_increment * torch.arange(channels // 2).float()
|
||||
)
|
||||
scaled_time = (
|
||||
torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
)
|
||||
positional_embedding = torch.cat(
|
||||
[torch.sin(scaled_time), torch.cos(scaled_time)], dim=1
|
||||
)
|
||||
self.register_buffer(
|
||||
"positional_embedding", positional_embedding, persistent=False
|
||||
)
|
||||
|
||||
def forward(self, seqlen: int) -> torch.Tensor:
|
||||
return self.positional_embedding[:seqlen, :]
|
||||
|
||||
|
||||
class Qwen3OmniMoeAudioAttention(nn.Module):
|
||||
"""Multi-headed attention for Qwen3-Omni Audio Encoder using MMEncoderAttention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3OmniMoeAudioEncoderConfig,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.num_heads = config.encoder_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_local_heads = self.num_heads // tp_size
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
|
||||
f"{self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
total_num_kv_heads=self.num_heads,
|
||||
bias=True,
|
||||
prefix=f"{prefix}.qkv",
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
bias=True,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
seq_length, _ = hidden_states.size()
|
||||
qkv, _ = self.qkv(hidden_states)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
q = q.view(1, seq_length, -1, self.head_dim)
|
||||
k = k.view(1, seq_length, -1, self.head_dim)
|
||||
v = v.view(1, seq_length, -1, self.head_dim)
|
||||
|
||||
attn_output = self.attn(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(seq_length, -1)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
|
||||
"""Transformer encoder layer for Qwen3-Omni Audio Encoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3OmniMoeAudioEncoderConfig,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = Qwen3OmniMoeAudioAttention(
|
||||
config, multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn"
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
self.embed_dim,
|
||||
config.encoder_ffn_dim,
|
||||
bias=True,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.encoder_ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
prefix=f"{prefix}.fc2",
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: Input tensor of shape (seq_len, hidden_size)
|
||||
cu_seqlens: Cumulative sequence lengths
|
||||
max_seqlen: Maximum sequence length in the batch
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Clamp for numerical stability with fp16
|
||||
if hidden_states.dtype == torch.float16:
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(
|
||||
hidden_states, min=-clamp_value, max=clamp_value
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen3OmniMoeAudioEncoder(nn.Module):
|
||||
"""vLLM-native Qwen3-Omni Audio Encoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3OmniMoeAudioEncoderConfig,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
embed_dim = config.d_model
|
||||
self.num_mel_bins = config.num_mel_bins
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.n_window = config.n_window
|
||||
self.n_window_infer = config.n_window_infer
|
||||
self.conv_chunksize = config.conv_chunksize
|
||||
|
||||
# Position embedding
|
||||
self.positional_embedding = SinusoidsPositionEmbedding(
|
||||
self.max_source_positions, embed_dim
|
||||
)
|
||||
|
||||
# Convolutional layers for mel-spectrogram processing
|
||||
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
|
||||
self.conv2d2 = nn.Conv2d(
|
||||
config.downsample_hidden_size,
|
||||
config.downsample_hidden_size,
|
||||
3,
|
||||
2,
|
||||
padding=1,
|
||||
)
|
||||
self.conv2d3 = nn.Conv2d(
|
||||
config.downsample_hidden_size,
|
||||
config.downsample_hidden_size,
|
||||
3,
|
||||
2,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
conv_out_dim = config.downsample_hidden_size * (
|
||||
(((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
|
||||
)
|
||||
self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False)
|
||||
|
||||
# Transformer encoder layers
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen3OmniMoeAudioEncoderLayer(
|
||||
config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.layers.{i}",
|
||||
)
|
||||
for i in range(config.encoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Output layers
|
||||
self.ln_post = nn.LayerNorm(config.d_model)
|
||||
self.proj1 = nn.Linear(config.d_model, config.d_model)
|
||||
self.act = _ACTIVATION_REGISTRY[config.activation_function]
|
||||
self.proj2 = nn.Linear(config.d_model, config.output_dim)
|
||||
|
||||
# Get attention backend
|
||||
attn_backend_override = (
|
||||
multimodal_config.mm_encoder_attn_backend
|
||||
if multimodal_config is not None
|
||||
else None
|
||||
)
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=config.d_model // config.encoder_attention_heads,
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Compute max_seqlen only for flash attention backends."""
|
||||
max_seqlen = None
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
return max_seqlen
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.conv2d1.weight.dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.conv2d1.weight.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_features: torch.Tensor,
|
||||
feature_lens: torch.Tensor,
|
||||
aftercnn_lens: torch.Tensor,
|
||||
):
|
||||
# Compute chunk information
|
||||
chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
|
||||
|
||||
chunk_lengths = torch.tensor(
|
||||
[self.n_window * 2] * chunk_num.sum(),
|
||||
dtype=torch.long,
|
||||
device=feature_lens.device,
|
||||
)
|
||||
tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
|
||||
chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
|
||||
chunk_lengths[chunk_lengths == 0] = self.n_window * 2
|
||||
|
||||
# Split input features into chunks and pad
|
||||
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
|
||||
padded_feature = nn.utils.rnn.pad_sequence(
|
||||
chunk_list, batch_first=True
|
||||
).transpose(1, 2)
|
||||
|
||||
# Compute feature lengths after CNN
|
||||
feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths)
|
||||
# Vectorized mask creation: avoid creating many small tensors
|
||||
max_len_after_cnn = feature_lens_after_cnn.max().item()
|
||||
indices = torch.arange(max_len_after_cnn, device=padded_feature.device)
|
||||
padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze(
|
||||
1
|
||||
)
|
||||
|
||||
# Add channel dimension for conv2d
|
||||
padded_feature = padded_feature.unsqueeze(1)
|
||||
|
||||
# Apply convolutional layers (chunk if needed to avoid OOM)
|
||||
if padded_feature.size(0) <= self.conv_chunksize:
|
||||
# Fast path: no chunking needed
|
||||
padded_embed = F.gelu(self.conv2d1(padded_feature))
|
||||
padded_embed = F.gelu(self.conv2d2(padded_embed))
|
||||
padded_embed = F.gelu(self.conv2d3(padded_embed))
|
||||
else:
|
||||
# Chunked processing to avoid OOM
|
||||
padded_embeds = []
|
||||
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
|
||||
padded_embed = F.gelu(self.conv2d1(chunk))
|
||||
padded_embed = F.gelu(self.conv2d2(padded_embed))
|
||||
padded_embed = F.gelu(self.conv2d3(padded_embed))
|
||||
padded_embeds.append(padded_embed)
|
||||
padded_embed = torch.cat(padded_embeds, dim=0)
|
||||
|
||||
# (batch, channels, freq, time) -> (batch, time, channels*freq)
|
||||
b, c, f, t = padded_embed.size()
|
||||
padded_embed = self.conv_out(
|
||||
padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
|
||||
)
|
||||
|
||||
# Add positional embedding
|
||||
positional_embedding = (
|
||||
self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
|
||||
.unsqueeze(0)
|
||||
.to(padded_embed.dtype)
|
||||
)
|
||||
padded_embed = padded_embed + positional_embedding
|
||||
|
||||
# Extract valid hidden states and compute cu_seqlens
|
||||
hidden_states = padded_embed[padded_mask_after_cnn]
|
||||
|
||||
# Compute cumulative sequence lengths for chunked attention
|
||||
cu_chunk_lens = [0]
|
||||
window_aftercnn = padded_mask_after_cnn.shape[-1] * (
|
||||
self.n_window_infer // (self.n_window * 2)
|
||||
)
|
||||
# Use tolist() for efficient batch conversion from tensor to Python
|
||||
for cnn_len in aftercnn_lens.tolist():
|
||||
num_full_chunks = cnn_len // window_aftercnn
|
||||
remainder = cnn_len % window_aftercnn
|
||||
cu_chunk_lens.extend([window_aftercnn] * num_full_chunks)
|
||||
if remainder:
|
||||
cu_chunk_lens.append(remainder)
|
||||
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
|
||||
-1, dtype=torch.int32
|
||||
)
|
||||
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
|
||||
# Apply transformer layers
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
|
||||
# Apply output layers
|
||||
hidden_states = self.ln_post(hidden_states)
|
||||
hidden_states = self.proj1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.proj2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after the three conv2d layers."""
|
||||
lengths = input_lengths
|
||||
for _ in range(3):
|
||||
lengths = (lengths - 1) // 2 + 1
|
||||
return lengths
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Load weights with mapping from HuggingFace format."""
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("self_attn.qkv.", "self_attn.q_proj.", "q"),
|
||||
("self_attn.qkv.", "self_attn.k_proj.", "k"),
|
||||
("self_attn.qkv.", "self_attn.v_proj.", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict.get(name)
|
||||
if param is not None:
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -144,7 +553,7 @@ class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
L, C = x.shape
|
||||
L, _ = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.hidden_size)
|
||||
return x
|
||||
@@ -224,7 +633,7 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor, # Only used for Flash Attention
|
||||
max_seqlen: torch.Tensor | None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attn(
|
||||
self.norm1(x),
|
||||
@@ -1142,12 +1551,11 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
|
||||
|
||||
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
|
||||
|
||||
audio_outputs = self.audio_tower(
|
||||
audio_features = self.audio_tower(
|
||||
input_features.to(self.audio_tower.dtype),
|
||||
feature_lens=audio_feature_lengths,
|
||||
aftercnn_lens=audio_output_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
return audio_features.split(audio_output_lengths.tolist())
|
||||
|
||||
|
||||
@@ -1205,21 +1613,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
self.config = thinker_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# force "use_flash_attention_2=True" to audio tower to align
|
||||
# the results.
|
||||
if flash_attn is not None:
|
||||
audio_config = thinker_config.audio_config
|
||||
audio_config._attn_implementation_autoset = True
|
||||
audio_config._attn_implementation = "flash_attention_2"
|
||||
else:
|
||||
logger.warning(
|
||||
"flash_attn is not available, the model may not yield the "
|
||||
"exactly same result as the transformers implementation "
|
||||
"in the audio tower part."
|
||||
self.audio_tower = Qwen3OmniMoeAudioEncoder(
|
||||
thinker_config.audio_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
|
||||
self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
|
||||
|
||||
self.visual = Qwen3Omni_VisionTransformer(
|
||||
vision_config=thinker_config.vision_config,
|
||||
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
|
||||
|
||||
Reference in New Issue
Block a user