From 3ee68590c7fafe05f1db1f1bee019c7b3a83ec96 Mon Sep 17 00:00:00 2001 From: AllenDou Date: Fri, 6 Mar 2026 00:07:37 +0800 Subject: [PATCH] refactor funasr model. (#36108) Signed-off-by: zixiao Co-authored-by: zixiao Co-authored-by: Isotr0py --- vllm/model_executor/models/funasr.py | 71 +++++-------------- .../models/qwen3_omni_moe_thinker.py | 2 +- .../processors/funasr_processor.py | 8 ++- 3 files changed, 24 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/models/funasr.py b/vllm/model_executor/models/funasr.py index 25ede72f1..de2e4409e 100644 --- a/vllm/model_executor/models/funasr.py +++ b/vllm/model_executor/models/funasr.py @@ -51,7 +51,6 @@ from vllm.multimodal.processing import ( ) from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor -from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( @@ -611,6 +610,10 @@ class FunASRAudioInputs(TensorSchema): list[torch.Tensor] | None, TensorShape("b"), ] + fake_token_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] class FunASREncoder(nn.Module): @@ -732,9 +735,6 @@ class FunASRProcessingInfo(BaseProcessingInfo): def get_target_channels(self) -> int: return 1 - def get_num_audio_tokens(self) -> int: - return self.get_hf_config().max_source_positions - class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -798,7 +798,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): return dict( input_features=MultiModalFieldConfig.batched("audio"), speech_lengths=MultiModalFieldConfig.batched("audio"), - fake_token_len=MultiModalFieldConfig.batched("audio"), + fake_token_lengths=MultiModalFieldConfig.batched("audio"), ) def _get_prompt_updates( @@ -812,22 +812,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): out_mm_data = out_mm_kwargs.get_data() - fake_token_len = out_mm_data.get("fake_token_len") - if fake_token_len is None: + fake_token_lengths = out_mm_data.get("fake_token_lengths") + if fake_token_lengths is None: audio_output_lengths = [] else: - assert isinstance(fake_token_len, torch.Tensor) + assert isinstance(fake_token_lengths, torch.Tensor) - audio_output_lengths = fake_token_len.tolist() + audio_output_lengths = fake_token_lengths.tolist() def get_replacement_qwen2_audio(item_idx: int): - if audio_output_lengths: - num_features = audio_output_lengths[item_idx] - else: - audio_embeds = out_mm_data["audio_embeds"][item_idx] - assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor" - num_features = audio_embeds.shape[0] - + num_features = audio_output_lengths[item_idx] return [audio_token_id] * num_features return [ @@ -847,21 +841,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): class FunASRForConditionalGeneration( nn.Module, SupportsTranscription, SupportsMultiModal ): - packed_modules_mapping = { - "self_attn.qkv_proj": [ - "self_attn.q_proj", - "self_attn.k_proj", - "self_attn.v_proj", - ], - "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], - } - hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ "linear_q.": "q_proj.", "linear_k.": "k_proj.", "linear_v.": "v_proj.", "linear_out.": "out_proj.", + "audio_adaptor.": "model.encoder.audio_adaptor.", + "audio_encoder.": "model.encoder.audio_encoder.", + "llm.model.": "model.decoder.", + "llm.lm_head": "lm_head", } ) @@ -969,9 +958,6 @@ class FunASRForConditionalGeneration( ) return decoder_outputs - def get_language_model(self) -> torch.nn.Module: - return self.model.decoder - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) @@ -1002,15 +988,12 @@ class FunASRForConditionalGeneration( def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs: input_features = kwargs.pop("input_features", None) speech_lengths = kwargs.pop("speech_lengths", None) - - if input_features is not None: - input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features) - - if speech_lengths is not None: - speech_lengths = json_map_leaves(lambda x: x.to(self.dtype), speech_lengths) + fake_token_lengths = kwargs.pop("fake_token_lengths", None) return FunASRAudioInputs( - input_features=input_features, speech_lengths=speech_lengths + input_features=input_features, + speech_lengths=speech_lengths, + fake_token_lengths=fake_token_lengths, ) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -1022,22 +1005,4 @@ class FunASRForConditionalGeneration( self, ) - # add fake zeros bias for k_proj to state_dict - weights = _create_fake_bias_for_k_proj(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - -def _create_fake_bias_for_k_proj( - weights: Iterable[tuple[str, torch.Tensor]], -) -> Iterable[tuple[str, torch.Tensor]]: - """ - Create full zeros bias for k_proj weight in self-attn and x-attn layers. - So that the bias for k_proj in qkv_proj can be initialized with zeros. - """ - for name, weight in weights: - if name.endswith(".k_proj.weight"): - bias = torch.zeros(weight.size(0)) - bias_name = name.replace("weight", "bias") - yield from [(name, weight), (bias_name, bias)] - else: - yield name, weight diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 1e6348b72..a6fcc74fa 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -1794,7 +1794,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/transformers_utils/processors/funasr_processor.py b/vllm/transformers_utils/processors/funasr_processor.py index c4cb2a2c4..bb6fe69ac 100644 --- a/vllm/transformers_utils/processors/funasr_processor.py +++ b/vllm/transformers_utils/processors/funasr_processor.py @@ -370,7 +370,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): ) olens = 1 + (speech_lengths - 3 + 2 * 1) // 2 olens = 1 + (olens - 3 + 2 * 1) // 2 - fake_token_len = (olens - 1) // 2 + 1 + fake_token_lengths = (olens - 1) // 2 + 1 if isinstance(input_features[0], list): padded_inputs["input_features"] = [ np.asarray(feature, dtype=np.float32) for feature in input_features @@ -382,8 +382,10 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + fake_token_lengths = torch.clamp(fake_token_lengths, min=1) + padded_inputs["speech_lengths"] = speech_lengths - padded_inputs["fake_token_len"] = fake_token_len + padded_inputs["fake_token_lengths"] = fake_token_lengths return padded_inputs @@ -471,7 +473,7 @@ class FunASRProcessor(ProcessorMixin): for sample in text: replace_str = [] while self.audio_token in sample: - num_audio_tokens = inputs["fake_token_len"].item() + num_audio_tokens = inputs["fake_token_lengths"].item() expanded_audio_token = self.audio_token * num_audio_tokens