refactor funasr model. (#36108)

Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
AllenDou
2026-03-06 00:07:37 +08:00
committed by GitHub
parent 7196348157
commit 3ee68590c7
3 changed files with 24 additions and 57 deletions

View File

@@ -51,7 +51,6 @@ from vllm.multimodal.processing import (
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
@@ -611,6 +610,10 @@ class FunASRAudioInputs(TensorSchema):
list[torch.Tensor] | None,
TensorShape("b"),
]
fake_token_lengths: Annotated[
list[torch.Tensor] | None,
TensorShape("b"),
]
class FunASREncoder(nn.Module):
@@ -732,9 +735,6 @@ class FunASRProcessingInfo(BaseProcessingInfo):
def get_target_channels(self) -> int:
return 1
def get_num_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
@@ -798,7 +798,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
return dict(
input_features=MultiModalFieldConfig.batched("audio"),
speech_lengths=MultiModalFieldConfig.batched("audio"),
fake_token_len=MultiModalFieldConfig.batched("audio"),
fake_token_lengths=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(
@@ -812,22 +812,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
out_mm_data = out_mm_kwargs.get_data()
fake_token_len = out_mm_data.get("fake_token_len")
if fake_token_len is None:
fake_token_lengths = out_mm_data.get("fake_token_lengths")
if fake_token_lengths is None:
audio_output_lengths = []
else:
assert isinstance(fake_token_len, torch.Tensor)
assert isinstance(fake_token_lengths, torch.Tensor)
audio_output_lengths = fake_token_len.tolist()
audio_output_lengths = fake_token_lengths.tolist()
def get_replacement_qwen2_audio(item_idx: int):
if audio_output_lengths:
num_features = audio_output_lengths[item_idx]
else:
audio_embeds = out_mm_data["audio_embeds"][item_idx]
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0]
num_features = audio_output_lengths[item_idx]
return [audio_token_id] * num_features
return [
@@ -847,21 +841,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
class FunASRForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"linear_q.": "q_proj.",
"linear_k.": "k_proj.",
"linear_v.": "v_proj.",
"linear_out.": "out_proj.",
"audio_adaptor.": "model.encoder.audio_adaptor.",
"audio_encoder.": "model.encoder.audio_encoder.",
"llm.model.": "model.decoder.",
"llm.lm_head": "lm_head",
}
)
@@ -969,9 +958,6 @@ class FunASRForConditionalGeneration(
)
return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
@@ -1002,15 +988,12 @@ class FunASRForConditionalGeneration(
def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs:
input_features = kwargs.pop("input_features", None)
speech_lengths = kwargs.pop("speech_lengths", None)
if input_features is not None:
input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
if speech_lengths is not None:
speech_lengths = json_map_leaves(lambda x: x.to(self.dtype), speech_lengths)
fake_token_lengths = kwargs.pop("fake_token_lengths", None)
return FunASRAudioInputs(
input_features=input_features, speech_lengths=speech_lengths
input_features=input_features,
speech_lengths=speech_lengths,
fake_token_lengths=fake_token_lengths,
)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -1022,22 +1005,4 @@ class FunASRForConditionalGeneration(
self,
)
# add fake zeros bias for k_proj to state_dict
weights = _create_fake_bias_for_k_proj(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create full zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros.
"""
for name, weight in weights:
if name.endswith(".k_proj.weight"):
bias = torch.zeros(weight.size(0))
bias_name = name.replace("weight", "bias")
yield from [(name, weight), (bias_name, bias)]
else:
yield name, weight

View File

@@ -1794,7 +1794,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return []
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary

View File

@@ -370,7 +370,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
)
olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_len = (olens - 1) // 2 + 1
fake_token_lengths = (olens - 1) // 2 + 1
if isinstance(input_features[0], list):
padded_inputs["input_features"] = [
np.asarray(feature, dtype=np.float32) for feature in input_features
@@ -382,8 +382,10 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
fake_token_lengths = torch.clamp(fake_token_lengths, min=1)
padded_inputs["speech_lengths"] = speech_lengths
padded_inputs["fake_token_len"] = fake_token_len
padded_inputs["fake_token_lengths"] = fake_token_lengths
return padded_inputs
@@ -471,7 +473,7 @@ class FunASRProcessor(ProcessorMixin):
for sample in text:
replace_str = []
while self.audio_token in sample:
num_audio_tokens = inputs["fake_token_len"].item()
num_audio_tokens = inputs["fake_token_lengths"].item()
expanded_audio_token = self.audio_token * num_audio_tokens