[Refactor] Remove dead private func _fp8_perm and _extract_mask_for_item (#35068)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-02-23 08:05:20 -05:00
committed by GitHub
parent 103e614b14
commit 7f40e9e516
2 changed files with 0 additions and 46 deletions

View File

@@ -130,39 +130,3 @@ def _group_audio_embeddings(
grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
current_idx += count
return tuple(grouped_embeddings)
def _normalize_to_tensor(mask: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
"""Convert mask to tensor, handling both list and tensor formats."""
if isinstance(mask, list):
return (
torch.stack(mask)
if mask and isinstance(mask[0], torch.Tensor)
else torch.tensor(mask)
)
return mask
def _extract_mask_for_item(
feature_attention_mask: torch.Tensor | list[torch.Tensor],
chunk_counts: torch.Tensor | list[int] | None,
item_idx: int,
) -> torch.Tensor:
"""Extract attention mask for a specific audio item."""
if chunk_counts is None:
# Single item per audio
mask = feature_attention_mask[item_idx]
if isinstance(feature_attention_mask, torch.Tensor):
return mask.unsqueeze(0)
return _normalize_to_tensor(mask)
# Multiple chunks per audio: calculate slice indices
counts = _as_list_chunk_counts(chunk_counts)
start_idx = sum(counts[:item_idx])
end_idx = start_idx + counts[item_idx]
# Extract slice
if isinstance(feature_attention_mask, torch.Tensor):
return feature_attention_mask[start_idx:end_idx]
mask_slice = feature_attention_mask[start_idx:end_idx]
return _normalize_to_tensor(mask_slice)