Re-enable the 80 char line width limit (#3305)
This commit is contained in:
@@ -255,9 +255,10 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
if metadata.sampling_params.use_beam_search:
|
||||
continue
|
||||
|
||||
if metadata.sampling_params.seed is not None \
|
||||
and expected_tokens[i] is None:
|
||||
# Record seeded random result to compare with results of second invocation
|
||||
if (metadata.sampling_params.seed is not None
|
||||
and expected_tokens[i] is None):
|
||||
# Record seeded random result to compare with results of
|
||||
# second invocation
|
||||
expected_tokens[i] = [
|
||||
nth_output.output_token
|
||||
for nth_output in sequence_output.samples
|
||||
@@ -265,11 +266,13 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
continue
|
||||
|
||||
for n, nth_output in enumerate(sequence_output.samples):
|
||||
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
|
||||
if (metadata.sampling_params.temperature == 0
|
||||
or metadata.sampling_params.seed is not None):
|
||||
# Ensure exact matches for greedy or random with seed
|
||||
assert nth_output.output_token == expected_tokens[i][n]
|
||||
else:
|
||||
# For non-seeded random check that one of the high-logit tokens were chosen
|
||||
# For non-seeded random check that one of the high-logit
|
||||
# tokens were chosen
|
||||
assert nth_output.output_token in expected_tokens[i]
|
||||
|
||||
# Test batch
|
||||
@@ -284,8 +287,8 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
input_tensor.data = input_tensor.index_select(0, target_index)
|
||||
fake_logits.data = fake_logits.index_select(0, target_index)
|
||||
|
||||
# This time, results of seeded random samples will be compared with the corresponding
|
||||
# sample in the pre-shuffled batch
|
||||
# This time, results of seeded random samples will be compared with
|
||||
# the corresponding sample in the pre-shuffled batch
|
||||
test_sampling(model_runner)
|
||||
|
||||
del model_runner
|
||||
|
||||
Reference in New Issue
Block a user