[BUGFIX]Fix Qwen-Omni models audio max_token_per_item estimation error leading to encoder_cache_size is 0 (#35994)

Signed-off-by: Miao, Avery <avery.miao@intel.com>
(cherry picked from commit e998fa76b9)
This commit is contained in:
Avery Miao
2026-03-06 01:16:29 +08:00
committed by khluu
parent fa78ec8a72
commit b7a423cb01
3 changed files with 86 additions and 0 deletions

View File

@@ -353,6 +353,39 @@ class Qwen2_5OmniThinkerProcessingInfo(
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None, "image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int] | None:
mm_counts = mm_counts or {}
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
mm_max_tokens: dict[str, int] = {}
if requested_modalities & {"image", "video"}:
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens.update(
{
m: vl_tokens[m]
for m in ["image", "video"]
if m in requested_modalities
}
)
if "audio" in requested_modalities:
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens["audio"] = audio_tokens["audio"]
return mm_max_tokens
class Qwen2_5OmniThinkerDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]

View File

@@ -179,6 +179,26 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
mm_counts = mm_counts or {}
if mm_counts.get("audio", 0) <= 0:
return {}
feature_extractor = self.get_feature_extractor()
chunk_length = min(feature_extractor.chunk_length, 30)
audio_len = int(chunk_length * feature_extractor.sampling_rate)
hop_length = feature_extractor.hop_length
max_mel_seq_len = audio_len // hop_length
input_lengths = torch.tensor([max_mel_seq_len], dtype=torch.long)
_, output_lengths = _get_feat_extract_output_lengths(input_lengths)
return {"audio": int(output_lengths.item())}
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:

View File

@@ -1163,6 +1163,39 @@ class Qwen3OmniMoeThinkerProcessingInfo(
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None, "image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int] | None:
mm_counts = mm_counts or {}
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
mm_max_tokens: dict[str, int] = {}
if requested_modalities & {"image", "video"}:
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens.update(
{
m: vl_tokens[m]
for m in ["image", "video"]
if m in requested_modalities
}
)
if "audio" in requested_modalities:
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
self,
seq_len=seq_len,
mm_counts=mm_counts,
)
mm_max_tokens["audio"] = audio_tokens["audio"]
return mm_max_tokens
Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder