[Bugfix] Fix Var Length Batched Padding in Granite Speech (#31906)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Author: Alex Brooks (committed by GitHub)
Date: 2026-01-09 03:28:43 -07:00
parent bde38c11df
commit dc77cb7129


@@ -672,7 +672,13 @@ class GraniteSpeechForConditionalGeneration(
         else:
             # Otherwise we have a list of tensors, which are almost certainly
-            # differing in their respective numbers of audio features;
+            # differing in their respective numbers of audio features; when
+            # passed as a batch, we expect a list of 2D var len input features
+            # so unsqueeze them.
+            input_features = [
+                feat.unsqueeze(dim=0) for feat in input_features if feat.ndim == 2
+            ]
             # stack them into a 3D tensor of size [bsz, most_num_features, 160].
             input_features = self._pad_and_stack_input_features(
                 input_features,
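The fix above normalizes each batch entry's rank before padding: a 2D per-sample tensor of shape [num_features, 160] gains a leading singleton dimension, matching what the pad-and-stack helper expects. A minimal torch-free stand-in using nested Python lists (the real code calls `torch.Tensor.unsqueeze`; the function and variable names here are illustrative, not from the diff):

```python
def ensure_leading_batch_dim(input_features):
    """Wrap each 2D entry [num_features, 160] as [1, num_features, 160].

    Illustrative stand-in for the diff's list comprehension; nested
    Python lists play the role of torch tensors.
    """
    def ndim(x):
        # Nesting depth of a regular nested list (analog of Tensor.ndim).
        depth = 0
        while isinstance(x, list):
            depth += 1
            x = x[0] if x else None
        return depth

    # Like the diff, entries that are not 2D are filtered out rather
    # than passed through unchanged.
    return [[feat] for feat in input_features if ndim(feat) == 2]

# Two variable-length 2D entries: 4 and 7 feature frames of width 160.
feats = [[[0.1] * 160] * 4, [[0.2] * 160] * 7]
wrapped = ensure_leading_batch_dim(feats)
```

Note that the comprehension's `if feat.ndim == 2` guard silently drops any entry that is not 2D, rather than passing already-3D entries through unchanged.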
@@ -724,13 +730,12 @@ class GraniteSpeechForConditionalGeneration(
         Args:
             input_features: list[torch.Tensor]
-                Input features to be coerced into a tensor.
+                3D Input features to be coerced into a tensor.
         Returns:
             torch.Tensor: Tensor of shape [bsz, num_features, 160], where
                 num_features is the max number of features of any entry in the
                 batch.
         """
-        # Input features are of shape [bsz, num_features, 160]
         feat_lens = [feats.shape[1] for feats in input_features]
         padding = [max(feat_lens) - length for length in feat_lens]
         # TODO (Alex) - Validate that it's okay to zero pad like this;
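The per-entry padding amounts computed above (max length minus each entry's length) drive the zero-padding that makes the batch rectangular. A minimal sketch of that pad-and-stack idea, using plain Python lists instead of torch tensors (the real `_pad_and_stack_input_features` operates on torch tensors; this function name and example data are illustrative):

```python
def pad_and_stack(features, feat_dim=160):
    """Zero pad variable-length feature entries and stack them.

    Each entry is shaped [num_features, feat_dim]; the result is shaped
    [bsz, max_num_features, feat_dim], mirroring the docstring in the diff.
    """
    # Mirror of the diff: collect per-entry lengths, then pad each entry
    # up to the longest one with rows of zeros.
    feat_lens = [len(feats) for feats in features]
    max_len = max(feat_lens)
    stacked = []
    for feats, length in zip(features, feat_lens):
        pad_rows = [[0.0] * feat_dim for _ in range(max_len - length)]
        stacked.append(list(feats) + pad_rows)
    return stacked

# Two entries with 3 and 5 feature frames respectively.
batch = [
    [[1.0] * 160] * 3,
    [[2.0] * 160] * 5,
]
stacked = pad_and_stack(batch)
```

Whether zero padding is semantically safe for the downstream encoder is exactly the open question flagged in the `TODO` above.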