[Qwen3-Omni] fixed _get_feat_extract_output_lengths function (#31007)

Signed-off-by: Xiong Wang <wangxiongts@163.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Xiong Wang
2025-12-24 13:33:54 +08:00
committed by GitHub
parent 369f47aa0f
commit bb24592d13

View File

@@ -118,7 +118,7 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
output_lengths = ( output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
) )
return feat_lengths, output_lengths return output_lengths
class Qwen3_VisionPatchEmbed(nn.Module): class Qwen3_VisionPatchEmbed(nn.Module):
@@ -921,13 +921,11 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if audio_feature_lengths is None and feature_attention_mask is None: if audio_feature_lengths is None and feature_attention_mask is None:
audio_output_lengths = [] audio_output_lengths = []
elif audio_feature_lengths is not None: elif audio_feature_lengths is not None:
_, audio_output_lens = _get_feat_extract_output_lengths( audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_feature_lengths
)
audio_output_lengths = audio_output_lens.tolist() audio_output_lengths = audio_output_lens.tolist()
elif feature_attention_mask is not None: elif feature_attention_mask is not None:
assert isinstance(feature_attention_mask, torch.Tensor) assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lens = _get_feat_extract_output_lengths( audio_output_lens = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1) feature_attention_mask.sum(-1)
) )
audio_output_lengths = audio_output_lens.tolist() audio_output_lengths = audio_output_lens.tolist()
@@ -1111,18 +1109,16 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
audio_input: Qwen2_5OmniAudioFeatureInputs, audio_input: Qwen2_5OmniAudioFeatureInputs,
audio_hashes: list[str] | None = None, audio_hashes: list[str] | None = None,
cached_audio_features: torch.Tensor | None = None, cached_audio_features: torch.Tensor | None = None,
) -> torch.Tensor: ) -> tuple[torch.Tensor, ...]:
input_features = audio_input["input_features"] input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"] audio_feature_lengths = audio_input["audio_feature_lengths"]
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_feature_lengths
)
audio_outputs = self.audio_tower( audio_outputs = self.audio_tower(
input_features.to(self.audio_tower.dtype), input_features.to(self.audio_tower.dtype),
feature_lens=audio_feature_lengths, feature_lens=audio_feature_lengths,
aftercnn_lens=audio_feat_lengths, aftercnn_lens=audio_output_lengths,
) )
audio_features = audio_outputs.last_hidden_state audio_features = audio_outputs.last_hidden_state
return audio_features.split(audio_output_lengths.tolist()) return audio_features.split(audio_output_lengths.tolist())
@@ -1579,7 +1575,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
+ st_idx + st_idx
) )
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
_, audio_len = _get_feat_extract_output_lengths( audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx] audio_feature_lengths[audio_idx]
) )
llm_pos_ids = ( llm_pos_ids = (
@@ -1700,7 +1696,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
llm_pos_ids_list.append(bos_block) llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block) llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
_, audio_len = _get_feat_extract_output_lengths( audio_len = _get_feat_extract_output_lengths(
audio_feature_lengths[audio_idx] audio_feature_lengths[audio_idx]
) )
audio_llm_pos_ids = ( audio_llm_pos_ids = (