[Bugfix] handle alignment of arguments in convert_sparse_cross_attention_mask_to_dense (#12347)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Wallas Santos <wallashss@ibm.com> Co-authored-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
@@ -1485,14 +1485,23 @@ def convert_sparse_cross_attention_mask_to_dense(
|
||||
total_length = sum(lengths)
|
||||
total_tiles = sum([sum(tiles) for tiles in num_tiles])
|
||||
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
|
||||
# A list of ranges, range[i] = [start, end] means
|
||||
# if the i-th sample has N tiles in total, the tiles[start, end]
|
||||
# will be used for cross-attention decoding.
|
||||
# A list of ranges, range[i] = [start, end] means that the i-th image will
|
||||
# use tiles[start, end] for cross-attention decoding.
|
||||
tile_range_for_decode = []
|
||||
|
||||
seq_start = 0
|
||||
tile_start = 0
|
||||
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
|
||||
|
||||
# sparse_mask has an [] entry for each sequence that does not have images,
|
||||
# but num_tiles does not have these entries...
|
||||
num_tiles_idx = 0
|
||||
for masks, length in zip(sparse_mask, lengths):
|
||||
if len(masks) == 0:
|
||||
# Text only
|
||||
continue
|
||||
|
||||
tiles = num_tiles[num_tiles_idx]
|
||||
num_tiles_idx += 1
|
||||
ts, td = -1, 0
|
||||
for mask, tile in zip(masks, tiles):
|
||||
if len(mask) != 2:
|
||||
@@ -1512,6 +1521,7 @@ def convert_sparse_cross_attention_mask_to_dense(
|
||||
assert td != 0
|
||||
tile_range_for_decode.append((ts, ts + td))
|
||||
seq_start += length
|
||||
assert num_tiles_idx == len(num_tiles)
|
||||
|
||||
return dense_mask, tile_range_for_decode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user