[core] remove beam search from the core (#9105)

This commit is contained in:
youkaichao
2024-10-06 22:47:04 -07:00
committed by GitHub
parent c8f26bb636
commit 18b296fdb2
25 changed files with 98 additions and 596 deletions

View File

@@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
assert first_sampler_output == second_sampler_output
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, sampling_params, device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
@@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_lens: List[int] = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
@@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
@@ -530,9 +507,6 @@ def test_sampler_mixed(seed: int, device: str):
zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None
if metadata.sampling_params.use_beam_search:
continue
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of