[Bugfix] Fix Var Length Batched Padding in Granite Speech (#31906)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -672,7 +672,13 @@ class GraniteSpeechForConditionalGeneration(
 
 else:
     # Otherwise we have a list of tensors, which are almost certainly
-    # differing in their respective numbers of audio features;
+    # differing in their respective numbers of audio features; when
+    # passed as a batch, we expect a list of 2D var len input features
+    # so unsqueeze them.
+    input_features = [
+        feat.unsqueeze(dim=0) for feat in input_features if feat.ndim == 2
+    ]
+
     # stack them into a 3D tensor of size [bsz, most_num_features, 160].
     input_features = self._pad_and_stack_input_features(
         input_features,
@@ -724,13 +730,12 @@ class GraniteSpeechForConditionalGeneration(
 
 Args:
     input_features: list[torch.Tensor]
-        Input features to be coerced into a tensor.
+        3D Input features to be coerced into a tensor.
 Returns:
     torch.Tensor: Tensor of shape [bsz, num_features, 160], where
     num_features is the max number of features of any entry in the
     batch.
 """
-# Input features are of shape [bsz, num_features, 160]
 feat_lens = [feats.shape[1] for feats in input_features]
 padding = [max(feat_lens) - length for length in feat_lens]
 # TODO (Alex) - Validate that it's okay to zero pad like this;
Reference in New Issue
Block a user