refactor funasr model. (#36108)
Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com> Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -51,7 +51,6 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
@@ -611,6 +610,10 @@ class FunASRAudioInputs(TensorSchema):
|
||||
list[torch.Tensor] | None,
|
||||
TensorShape("b"),
|
||||
]
|
||||
fake_token_lengths: Annotated[
|
||||
list[torch.Tensor] | None,
|
||||
TensorShape("b"),
|
||||
]
|
||||
|
||||
|
||||
class FunASREncoder(nn.Module):
|
||||
@@ -732,9 +735,6 @@ class FunASRProcessingInfo(BaseProcessingInfo):
|
||||
def get_target_channels(self) -> int:
|
||||
return 1
|
||||
|
||||
def get_num_audio_tokens(self) -> int:
|
||||
return self.get_hf_config().max_source_positions
|
||||
|
||||
|
||||
class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
@@ -798,7 +798,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
|
||||
return dict(
|
||||
input_features=MultiModalFieldConfig.batched("audio"),
|
||||
speech_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
fake_token_len=MultiModalFieldConfig.batched("audio"),
|
||||
fake_token_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@@ -812,22 +812,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
|
||||
fake_token_len = out_mm_data.get("fake_token_len")
|
||||
if fake_token_len is None:
|
||||
fake_token_lengths = out_mm_data.get("fake_token_lengths")
|
||||
if fake_token_lengths is None:
|
||||
audio_output_lengths = []
|
||||
else:
|
||||
assert isinstance(fake_token_len, torch.Tensor)
|
||||
assert isinstance(fake_token_lengths, torch.Tensor)
|
||||
|
||||
audio_output_lengths = fake_token_len.tolist()
|
||||
audio_output_lengths = fake_token_lengths.tolist()
|
||||
|
||||
def get_replacement_qwen2_audio(item_idx: int):
|
||||
if audio_output_lengths:
|
||||
num_features = audio_output_lengths[item_idx]
|
||||
else:
|
||||
audio_embeds = out_mm_data["audio_embeds"][item_idx]
|
||||
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
|
||||
num_features = audio_embeds.shape[0]
|
||||
|
||||
num_features = audio_output_lengths[item_idx]
|
||||
return [audio_token_id] * num_features
|
||||
|
||||
return [
|
||||
@@ -847,21 +841,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
|
||||
class FunASRForConditionalGeneration(
|
||||
nn.Module, SupportsTranscription, SupportsMultiModal
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
],
|
||||
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
|
||||
}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"linear_q.": "q_proj.",
|
||||
"linear_k.": "k_proj.",
|
||||
"linear_v.": "v_proj.",
|
||||
"linear_out.": "out_proj.",
|
||||
"audio_adaptor.": "model.encoder.audio_adaptor.",
|
||||
"audio_encoder.": "model.encoder.audio_encoder.",
|
||||
"llm.model.": "model.decoder.",
|
||||
"llm.lm_head": "lm_head",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -969,9 +958,6 @@ class FunASRForConditionalGeneration(
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model.decoder
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
|
||||
@@ -1002,15 +988,12 @@ class FunASRForConditionalGeneration(
|
||||
def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs:
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
speech_lengths = kwargs.pop("speech_lengths", None)
|
||||
|
||||
if input_features is not None:
|
||||
input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
|
||||
|
||||
if speech_lengths is not None:
|
||||
speech_lengths = json_map_leaves(lambda x: x.to(self.dtype), speech_lengths)
|
||||
fake_token_lengths = kwargs.pop("fake_token_lengths", None)
|
||||
|
||||
return FunASRAudioInputs(
|
||||
input_features=input_features, speech_lengths=speech_lengths
|
||||
input_features=input_features,
|
||||
speech_lengths=speech_lengths,
|
||||
fake_token_lengths=fake_token_lengths,
|
||||
)
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1022,22 +1005,4 @@ class FunASRForConditionalGeneration(
|
||||
self,
|
||||
)
|
||||
|
||||
# add fake zeros bias for k_proj to state_dict
|
||||
weights = _create_fake_bias_for_k_proj(weights)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
def _create_fake_bias_for_k_proj(
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
"""
|
||||
Create full zeros bias for k_proj weight in self-attn and x-attn layers.
|
||||
So that the bias for k_proj in qkv_proj can be initialized with zeros.
|
||||
"""
|
||||
for name, weight in weights:
|
||||
if name.endswith(".k_proj.weight"):
|
||||
bias = torch.zeros(weight.size(0))
|
||||
bias_name = name.replace("weight", "bias")
|
||||
yield from [(name, weight), (bias_name, bias)]
|
||||
else:
|
||||
yield name, weight
|
||||
|
||||
@@ -1794,7 +1794,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
return []
|
||||
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
# tensor correspoending to a multimodal data item (image or video).
|
||||
# tensor corresponding to a multimodal data item (image or video).
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
|
||||
@@ -370,7 +370,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
|
||||
)
|
||||
olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
|
||||
olens = 1 + (olens - 3 + 2 * 1) // 2
|
||||
fake_token_len = (olens - 1) // 2 + 1
|
||||
fake_token_lengths = (olens - 1) // 2 + 1
|
||||
if isinstance(input_features[0], list):
|
||||
padded_inputs["input_features"] = [
|
||||
np.asarray(feature, dtype=np.float32) for feature in input_features
|
||||
@@ -382,8 +382,10 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
fake_token_lengths = torch.clamp(fake_token_lengths, min=1)
|
||||
|
||||
padded_inputs["speech_lengths"] = speech_lengths
|
||||
padded_inputs["fake_token_len"] = fake_token_len
|
||||
padded_inputs["fake_token_lengths"] = fake_token_lengths
|
||||
|
||||
return padded_inputs
|
||||
|
||||
@@ -471,7 +473,7 @@ class FunASRProcessor(ProcessorMixin):
|
||||
for sample in text:
|
||||
replace_str = []
|
||||
while self.audio_token in sample:
|
||||
num_audio_tokens = inputs["fake_token_len"].item()
|
||||
num_audio_tokens = inputs["fake_token_lengths"].item()
|
||||
|
||||
expanded_audio_token = self.audio_token * num_audio_tokens
|
||||
|
||||
|
||||
Reference in New Issue
Block a user