[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)
This commit is contained in:
@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
|
||||
|
||||
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 0]
|
||||
filtered_groups, indices = split_batch_by_proposal_len(
|
||||
fake_sequence_group_metadata,
|
||||
proposal_lens,
|
||||
select_proposal_len_zero=True)
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
|
||||
@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||
|
||||
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 2]
|
||||
filtered_groups, indices = split_batch_by_proposal_len(
|
||||
fake_sequence_group_metadata,
|
||||
proposal_lens,
|
||||
select_proposal_len_zero=False)
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
|
||||
@@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||
|
||||
|
||||
def test_empty_inputs():
|
||||
filtered_groups, indices = split_batch_by_proposal_len(
|
||||
[], [], select_proposal_len_zero=True)
|
||||
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
@@ -95,10 +92,9 @@ def test_empty_inputs():
|
||||
|
||||
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 0, 0]
|
||||
filtered_groups, indices = split_batch_by_proposal_len(
|
||||
fake_sequence_group_metadata,
|
||||
proposal_lens,
|
||||
select_proposal_len_zero=False)
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
@@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||
|
||||
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [1, 1, 1]
|
||||
filtered_groups, indices = split_batch_by_proposal_len(
|
||||
fake_sequence_group_metadata,
|
||||
proposal_lens,
|
||||
select_proposal_len_zero=True)
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
Reference in New Issue
Block a user