[Qwen3-Omni] Fix `_get_feat_extract_output_lengths` to return only the output lengths (#31007)
Signed-off-by: Xiong Wang <wangxiongts@163.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -118,7 +118,7 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
|||||||
output_lengths = (
|
output_lengths = (
|
||||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||||
)
|
)
|
||||||
return feat_lengths, output_lengths
|
return output_lengths
|
||||||
|
|
||||||
|
|
||||||
class Qwen3_VisionPatchEmbed(nn.Module):
|
class Qwen3_VisionPatchEmbed(nn.Module):
|
||||||
@@ -921,13 +921,11 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
|||||||
if audio_feature_lengths is None and feature_attention_mask is None:
|
if audio_feature_lengths is None and feature_attention_mask is None:
|
||||||
audio_output_lengths = []
|
audio_output_lengths = []
|
||||||
elif audio_feature_lengths is not None:
|
elif audio_feature_lengths is not None:
|
||||||
_, audio_output_lens = _get_feat_extract_output_lengths(
|
audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
|
||||||
audio_feature_lengths
|
|
||||||
)
|
|
||||||
audio_output_lengths = audio_output_lens.tolist()
|
audio_output_lengths = audio_output_lens.tolist()
|
||||||
elif feature_attention_mask is not None:
|
elif feature_attention_mask is not None:
|
||||||
assert isinstance(feature_attention_mask, torch.Tensor)
|
assert isinstance(feature_attention_mask, torch.Tensor)
|
||||||
_, audio_output_lens = _get_feat_extract_output_lengths(
|
audio_output_lens = _get_feat_extract_output_lengths(
|
||||||
feature_attention_mask.sum(-1)
|
feature_attention_mask.sum(-1)
|
||||||
)
|
)
|
||||||
audio_output_lengths = audio_output_lens.tolist()
|
audio_output_lengths = audio_output_lens.tolist()
|
||||||
@@ -1111,18 +1109,16 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
|
|||||||
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
||||||
audio_hashes: list[str] | None = None,
|
audio_hashes: list[str] | None = None,
|
||||||
cached_audio_features: torch.Tensor | None = None,
|
cached_audio_features: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
input_features = audio_input["input_features"]
|
input_features = audio_input["input_features"]
|
||||||
audio_feature_lengths = audio_input["audio_feature_lengths"]
|
audio_feature_lengths = audio_input["audio_feature_lengths"]
|
||||||
|
|
||||||
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
|
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
|
||||||
audio_feature_lengths
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_outputs = self.audio_tower(
|
audio_outputs = self.audio_tower(
|
||||||
input_features.to(self.audio_tower.dtype),
|
input_features.to(self.audio_tower.dtype),
|
||||||
feature_lens=audio_feature_lengths,
|
feature_lens=audio_feature_lengths,
|
||||||
aftercnn_lens=audio_feat_lengths,
|
aftercnn_lens=audio_output_lengths,
|
||||||
)
|
)
|
||||||
audio_features = audio_outputs.last_hidden_state
|
audio_features = audio_outputs.last_hidden_state
|
||||||
return audio_features.split(audio_output_lengths.tolist())
|
return audio_features.split(audio_output_lengths.tolist())
|
||||||
@@ -1579,7 +1575,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
|||||||
+ st_idx
|
+ st_idx
|
||||||
)
|
)
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||||
_, audio_len = _get_feat_extract_output_lengths(
|
audio_len = _get_feat_extract_output_lengths(
|
||||||
audio_feature_lengths[audio_idx]
|
audio_feature_lengths[audio_idx]
|
||||||
)
|
)
|
||||||
llm_pos_ids = (
|
llm_pos_ids = (
|
||||||
@@ -1700,7 +1696,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
|||||||
llm_pos_ids_list.append(bos_block)
|
llm_pos_ids_list.append(bos_block)
|
||||||
llm_pos_ids_list.append(bos_block)
|
llm_pos_ids_list.append(bos_block)
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||||
_, audio_len = _get_feat_extract_output_lengths(
|
audio_len = _get_feat_extract_output_lengths(
|
||||||
audio_feature_lengths[audio_idx]
|
audio_feature_lengths[audio_idx]
|
||||||
)
|
)
|
||||||
audio_llm_pos_ids = (
|
audio_llm_pos_ids = (
|
||||||
|
|||||||
Reference in New Issue
Block a user