[CI] Add Buildkite (#2355)

This commit is contained in:
Simon Mo
2024-01-14 12:37:58 -08:00
committed by GitHub
parent 9f659bf07f
commit 6e01e8c1c8
13 changed files with 192 additions and 37 deletions

View File

@@ -75,6 +75,8 @@ def test_sampler_all_greedy(seed: int):
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
@@ -111,6 +113,8 @@ def test_sampler_all_random(seed: int):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
@@ -144,6 +148,7 @@ def test_sampler_all_beam(seed: int):
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -198,6 +203,8 @@ def test_sampler_mixed(seed: int):
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
@@ -235,6 +242,8 @@ def test_sampler_logits_processors(seed: int):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_top_k_top_p(seed: int):
@@ -296,3 +305,5 @@ def test_sampler_top_k_top_p(seed: int):
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
del model_runner