[Bugfix] handle alignment of arguments in convert_sparse_cross_attention_mask_to_dense (#12347)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Wallas Santos <wallashss@ibm.com>
Co-authored-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
Travis Johnson
2025-01-29 01:54:35 -07:00
committed by GitHub
parent ef001d98ef
commit 036ca94c25
2 changed files with 222 additions and 4 deletions

View File

@@ -1485,14 +1485,23 @@ def convert_sparse_cross_attention_mask_to_dense(
total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
# A list of ranges, range[i] = [start, end] means
# if the i-th sample has N tiles in total, the tiles[start, end]
# will be used for cross-attention decoding.
# A list of ranges, range[i] = [start, end] means that the i-th image will
# use tiles[start, end] for cross-attention decoding.
tile_range_for_decode = []
seq_start = 0
tile_start = 0
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
# sparse_mask has an [] entry for each sequence that does not have images,
# but num_tiles does not have these entries...
num_tiles_idx = 0
for masks, length in zip(sparse_mask, lengths):
if len(masks) == 0:
# Text only
continue
tiles = num_tiles[num_tiles_idx]
num_tiles_idx += 1
ts, td = -1, 0
for mask, tile in zip(masks, tiles):
if len(mask) != 2:
@@ -1512,6 +1521,7 @@ def convert_sparse_cross_attention_mask_to_dense(
assert td != 0
tile_range_for_decode.append((ts, ts + td))
seq_start += length
assert num_tiles_idx == len(num_tiles)
return dense_mask, tile_range_for_decode