Re-enable the 80 char line width limit (#3305)

This commit is contained in:
Zhuohan Li
2024-03-10 19:49:14 -07:00
committed by GitHub
parent 4b59f00e91
commit 2f8844ba08
67 changed files with 557 additions and 528 deletions

View File

@@ -255,9 +255,10 @@ def test_sampler_mixed(seed: int, device: str):
if metadata.sampling_params.use_beam_search:
continue
if metadata.sampling_params.seed is not None \
and expected_tokens[i] is None:
# Record seeded random result to compare with results of second invocation
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of
# second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
@@ -265,11 +266,13 @@ def test_sampler_mixed(seed: int, device: str):
continue
for n, nth_output in enumerate(sequence_output.samples):
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
else:
# For non-seeded random check that one of the high-logit tokens were chosen
# For non-seeded random check that one of the high-logit
# tokens were chosen
assert nth_output.output_token in expected_tokens[i]
# Test batch
@@ -284,8 +287,8 @@ def test_sampler_mixed(seed: int, device: str):
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with the corresponding
# sample in the pre-shuffled batch
# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling(model_runner)
del model_runner