[Refactor] GLM-ASR Modeling (#31779)
Signed-off-by: JaredforReal <w13431838023@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -71,14 +71,37 @@ def _get_audio_output_lengths_for_tower(
|
||||
merge_factor: int,
|
||||
conv_params: list[tuple[int, int, int]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the output lengths after audio processing.
|
||||
|
||||
The output length accounts for:
|
||||
1. Convolution layers (downsampling)
|
||||
2. Merge factor (further downsampling during projection)
|
||||
|
||||
Args:
|
||||
audio_tower: The audio encoder module
|
||||
audio_lengths: Input feature lengths [batch_size]
|
||||
merge_factor: Factor for merging adjacent features
|
||||
conv_params: List of (padding, kernel_size, stride) for each conv layer
|
||||
|
||||
Returns:
|
||||
Output lengths after all processing [batch_size]
|
||||
"""
|
||||
# First, calculate the output length after convolutions
|
||||
if hasattr(audio_tower, "_get_feat_extract_output_lengths"):
|
||||
_, audio_output_lengths = audio_tower._get_feat_extract_output_lengths(
|
||||
_, conv_output_lengths = audio_tower._get_feat_extract_output_lengths(
|
||||
audio_lengths
|
||||
)
|
||||
return audio_output_lengths
|
||||
return _get_audio_output_lengths_from_lengths(
|
||||
audio_lengths, merge_factor, conv_params
|
||||
)
|
||||
else:
|
||||
conv_output_lengths = audio_lengths
|
||||
for padding, kernel_size, stride in conv_params:
|
||||
conv_output_lengths = _calculate_conv_output_length(
|
||||
conv_output_lengths, padding, kernel_size, stride
|
||||
)
|
||||
|
||||
# Then, apply merge_factor to get final output length
|
||||
# Formula: (conv_output_lengths - merge_factor) // merge_factor + 1
|
||||
return (conv_output_lengths - merge_factor) // merge_factor + 1
|
||||
|
||||
|
||||
def _flatten_audio_features_by_length(
|
||||
|
||||
Reference in New Issue
Block a user