FunASR model bugfix (#36633)
Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
This commit is contained in:
@@ -573,6 +573,8 @@ class Transformer(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, ilens: int = 0):
|
||||
max_len = max(ilens)
|
||||
hidden_states = hidden_states[:, :max_len, :]
|
||||
batch_size, seq_len, dim = hidden_states.size()
|
||||
chunk_num = (seq_len - 1) // self.k + 1
|
||||
pad_num = chunk_num * self.k - seq_len
|
||||
|
||||
@@ -268,6 +268,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
|
||||
n_fft=400,
|
||||
padding_value=0.0,
|
||||
dither=0.0,
|
||||
max_length=1000,
|
||||
return_attention_mask=False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -279,6 +280,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
|
||||
**kwargs,
|
||||
)
|
||||
self.frontend_conf = kwargs.get("frontend_conf", {})
|
||||
self.max_length = max_length
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.chunk_length = chunk_length
|
||||
@@ -329,64 +331,41 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
|
||||
return_token_timestamps: bool | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
is_batched = isinstance(raw_speech, (list, tuple)) and (
|
||||
isinstance(raw_speech[0], (np.ndarray, tuple, list))
|
||||
)
|
||||
frontend = WavFrontend(**self.frontend_conf, dither=self.dither)
|
||||
|
||||
if is_batched:
|
||||
raw_speech = [
|
||||
np.asarray([speech], dtype=np.float32).T for speech in raw_speech
|
||||
]
|
||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||
raw_speech = np.asarray(raw_speech, dtype=np.float32)
|
||||
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
|
||||
np.float64
|
||||
):
|
||||
raw_speech = raw_speech.astype(np.float32)
|
||||
feats = []
|
||||
speech_lengths = []
|
||||
fake_token_lengths = []
|
||||
for speech in raw_speech:
|
||||
feature, length = self.extract_fbank(
|
||||
speech,
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend,
|
||||
is_final=True,
|
||||
)
|
||||
feats.append(feature)
|
||||
speech_lengths.append(length)
|
||||
olens = 1 + (length - 3 + 2 * 1) // 2
|
||||
olens = 1 + (olens - 3 + 2 * 1) // 2
|
||||
fake_token_len = (olens - 1) // 2 + 1
|
||||
fake_token_len = torch.clamp(fake_token_len, min=1)
|
||||
fake_token_lengths.append(fake_token_len)
|
||||
|
||||
if not is_batched:
|
||||
raw_speech = [np.asarray([raw_speech]).T]
|
||||
|
||||
batched_speech = BatchFeature({"input_features": raw_speech})
|
||||
|
||||
padded_inputs = self.pad(
|
||||
batched_speech,
|
||||
feats = torch.concat(feats, dim=0)
|
||||
batched_speech = self.pad(
|
||||
BatchFeature({"input_features": feats}),
|
||||
padding=padding,
|
||||
max_length=max_length if max_length else self.n_samples,
|
||||
max_length=max_length if max_length else self.max_length,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask or do_normalize,
|
||||
)
|
||||
|
||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||
|
||||
frontend = WavFrontend(**self.frontend_conf, dither=self.dither)
|
||||
input_features, speech_lengths = self.extract_fbank(
|
||||
input_features[0],
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend,
|
||||
is_final=True,
|
||||
)
|
||||
olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
|
||||
olens = 1 + (olens - 3 + 2 * 1) // 2
|
||||
fake_token_lengths = (olens - 1) // 2 + 1
|
||||
if isinstance(input_features[0], list):
|
||||
padded_inputs["input_features"] = [
|
||||
np.asarray(feature, dtype=np.float32) for feature in input_features
|
||||
]
|
||||
|
||||
else:
|
||||
padded_inputs["input_features"] = input_features
|
||||
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
batched_speech = batched_speech.convert_to_tensors(return_tensors)
|
||||
|
||||
fake_token_lengths = torch.clamp(fake_token_lengths, min=1)
|
||||
|
||||
padded_inputs["speech_lengths"] = speech_lengths
|
||||
padded_inputs["fake_token_lengths"] = fake_token_lengths
|
||||
|
||||
return padded_inputs
|
||||
batched_speech["speech_lengths"] = torch.tensor(speech_lengths)
|
||||
batched_speech["fake_token_lengths"] = torch.concat(fake_token_lengths)
|
||||
return batched_speech
|
||||
|
||||
|
||||
class FunASRProcessor(ProcessorMixin):
|
||||
|
||||
Reference in New Issue
Block a user