[Refactor] Remove dead private func _fp8_perm and _extract_mask_for_item (#35068)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -130,39 +130,3 @@ def _group_audio_embeddings(
|
||||
grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
|
||||
current_idx += count
|
||||
return tuple(grouped_embeddings)
|
||||
|
||||
|
||||
def _normalize_to_tensor(mask: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
|
||||
"""Convert mask to tensor, handling both list and tensor formats."""
|
||||
if isinstance(mask, list):
|
||||
return (
|
||||
torch.stack(mask)
|
||||
if mask and isinstance(mask[0], torch.Tensor)
|
||||
else torch.tensor(mask)
|
||||
)
|
||||
return mask
|
||||
|
||||
|
||||
def _extract_mask_for_item(
|
||||
feature_attention_mask: torch.Tensor | list[torch.Tensor],
|
||||
chunk_counts: torch.Tensor | list[int] | None,
|
||||
item_idx: int,
|
||||
) -> torch.Tensor:
|
||||
"""Extract attention mask for a specific audio item."""
|
||||
if chunk_counts is None:
|
||||
# Single item per audio
|
||||
mask = feature_attention_mask[item_idx]
|
||||
if isinstance(feature_attention_mask, torch.Tensor):
|
||||
return mask.unsqueeze(0)
|
||||
return _normalize_to_tensor(mask)
|
||||
|
||||
# Multiple chunks per audio: calculate slice indices
|
||||
counts = _as_list_chunk_counts(chunk_counts)
|
||||
start_idx = sum(counts[:item_idx])
|
||||
end_idx = start_idx + counts[item_idx]
|
||||
|
||||
# Extract slice
|
||||
if isinstance(feature_attention_mask, torch.Tensor):
|
||||
return feature_attention_mask[start_idx:end_idx]
|
||||
mask_slice = feature_attention_mask[start_idx:end_idx]
|
||||
return _normalize_to_tensor(mask_slice)
|
||||
|
||||
Reference in New Issue
Block a user