[core] remove beam search from the core (#9105)
This commit is contained in:
@@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
|
||||
assert first_sampler_output == second_sampler_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_beam(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
best_of=2,
|
||||
use_beam_search=True,
|
||||
)
|
||||
_do_sample(batch_size, fake_logits, sampler, sampling_params, device)
|
||||
# no assertion here as I am not sure how to determine whether
|
||||
# the outputs are expected - in other words, this just tests
|
||||
# whether there are no exceptions in the sampler
|
||||
# when handling an all-beam search case.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
@@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
seq_lens: List[int] = []
|
||||
for i in range(batch_size):
|
||||
expected: Optional[List[int]] = None
|
||||
sampling_type = random.randint(0, 3)
|
||||
sampling_type = random.randint(0, 2)
|
||||
if sampling_type == 0:
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
|
||||
@@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
for idx in range(n):
|
||||
fake_logits[i, i + idx] = 1e2
|
||||
expected = list(range(i, i + n))
|
||||
else:
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
use_beam_search=True,
|
||||
best_of=2)
|
||||
|
||||
expected_tokens.append(expected)
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
@@ -530,9 +507,6 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
zip(sampler_output, seq_group_metadata_list)):
|
||||
assert metadata.sampling_params is not None
|
||||
|
||||
if metadata.sampling_params.use_beam_search:
|
||||
continue
|
||||
|
||||
if (metadata.sampling_params.seed is not None
|
||||
and expected_tokens[i] is None):
|
||||
# Record seeded random result to compare with results of
|
||||
|
||||
Reference in New Issue
Block a user