diff --git a/vllm/model_executor/models/funasr.py b/vllm/model_executor/models/funasr.py
index 591a0184a..78acca3c2 100644
--- a/vllm/model_executor/models/funasr.py
+++ b/vllm/model_executor/models/funasr.py
@@ -573,6 +573,8 @@ class Transformer(nn.Module):
         )
 
     def forward(self, hidden_states: torch.Tensor, ilens: int = 0):
+        max_len = max(ilens)
+        hidden_states = hidden_states[:, :max_len, :]
         batch_size, seq_len, dim = hidden_states.size()
         chunk_num = (seq_len - 1) // self.k + 1
         pad_num = chunk_num * self.k - seq_len
diff --git a/vllm/transformers_utils/processors/funasr.py b/vllm/transformers_utils/processors/funasr.py
index 1ce653c2e..d7a3c4060 100644
--- a/vllm/transformers_utils/processors/funasr.py
+++ b/vllm/transformers_utils/processors/funasr.py
@@ -268,6 +268,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
         n_fft=400,
         padding_value=0.0,
         dither=0.0,
+        max_length=1000,
         return_attention_mask=False,
         **kwargs,
     ):
@@ -279,6 +280,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
             **kwargs,
         )
         self.frontend_conf = kwargs.get("frontend_conf", {})
+        self.max_length = max_length
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.chunk_length = chunk_length
@@ -329,64 +331,41 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
         return_token_timestamps: bool | None = None,
         **kwargs,
     ) -> BatchFeature:
-        is_batched = isinstance(raw_speech, (list, tuple)) and (
-            isinstance(raw_speech[0], (np.ndarray, tuple, list))
-        )
+        frontend = WavFrontend(**self.frontend_conf, dither=self.dither)
 
-        if is_batched:
-            raw_speech = [
-                np.asarray([speech], dtype=np.float32).T for speech in raw_speech
-            ]
-        elif not is_batched and not isinstance(raw_speech, np.ndarray):
-            raw_speech = np.asarray(raw_speech, dtype=np.float32)
-        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
-            np.float64
-        ):
-            raw_speech = raw_speech.astype(np.float32)
+        feats = []
+        speech_lengths = []
+        fake_token_lengths = []
+        for speech in raw_speech:
+            feature, length = self.extract_fbank(
+                speech,
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=frontend,
+                is_final=True,
+            )
+            feats.append(feature)
+            speech_lengths.append(length)
+            olens = 1 + (length - 3 + 2 * 1) // 2
+            olens = 1 + (olens - 3 + 2 * 1) // 2
+            fake_token_len = (olens - 1) // 2 + 1
+            fake_token_len = torch.clamp(fake_token_len, min=1)
+            fake_token_lengths.append(fake_token_len)
 
-        if not is_batched:
-            raw_speech = [np.asarray([raw_speech]).T]
-
-        batched_speech = BatchFeature({"input_features": raw_speech})
-
-        padded_inputs = self.pad(
-            batched_speech,
+        feats = torch.concat(feats, dim=0)
+        batched_speech = self.pad(
+            BatchFeature({"input_features": feats}),
             padding=padding,
-            max_length=max_length if max_length else self.n_samples,
+            max_length=max_length if max_length else self.max_length,
             truncation=truncation,
             pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask or do_normalize,
         )
-
-        input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
-
-        frontend = WavFrontend(**self.frontend_conf, dither=self.dither)
-        input_features, speech_lengths = self.extract_fbank(
-            input_features[0],
-            data_type=kwargs.get("data_type", "sound"),
-            frontend=frontend,
-            is_final=True,
-        )
-        olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
-        olens = 1 + (olens - 3 + 2 * 1) // 2
-        fake_token_lengths = (olens - 1) // 2 + 1
-        if isinstance(input_features[0], list):
-            padded_inputs["input_features"] = [
-                np.asarray(feature, dtype=np.float32) for feature in input_features
-            ]
-
-        else:
-            padded_inputs["input_features"] = input_features
-        if return_tensors is not None:
-            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+        batched_speech = batched_speech.convert_to_tensors(return_tensors)
 
-        fake_token_lengths = torch.clamp(fake_token_lengths, min=1)
-
-        padded_inputs["speech_lengths"] = speech_lengths
-        padded_inputs["fake_token_lengths"] = fake_token_lengths
-
-        return padded_inputs
+        batched_speech["speech_lengths"] = torch.tensor(speech_lengths)
+        batched_speech["fake_token_lengths"] = torch.concat(fake_token_lengths)
+        return batched_speech
 
 
 class FunASRProcessor(ProcessorMixin):
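Note on the length arithmetic in the new per-sample loop: below is a minimal standalone sketch of how `fake_token_len` is derived from an fbank frame count, assuming the two `olens` formulas correspond to stacked convolutions with kernel_size=3, stride=2, padding=1 (the generic `1 + (L - k + 2p) // s` output-length rule) followed by one more stride-2 subsampling. The function name and the example frame counts are illustrative, not part of the patch.

    import torch

    def fake_token_length(num_frames: torch.Tensor) -> torch.Tensor:
        # Mirrors the loop in FunASRFeatureExtractor.__call__: two k=3/s=2/p=1
        # reductions, then (olens - 1) // 2 + 1, clamped to at least 1 token.
        olens = 1 + (num_frames - 3 + 2 * 1) // 2  # first subsampling stage
        olens = 1 + (olens - 3 + 2 * 1) // 2       # second subsampling stage
        fake = (olens - 1) // 2 + 1                # final stride-2 reduction
        return torch.clamp(fake, min=1)            # guard very short clips

    # e.g. 998 fbank frames -> 499 -> 250 -> 125 fake tokens
    print(fake_token_length(torch.tensor([998, 4])))  # tensor([125, 1])

Because the clamp now runs inside the loop on each sample's length rather than once on the batched tensor, per-utterance placeholder counts stay correct even when short and long clips are mixed in one batch.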