[BugFix] Fix test breakages from transformers 4.45 upgrade (#8829)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -596,8 +597,19 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
generation_config = GenerationConfig(top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=True)
|
||||
warpers = generation_model._get_logits_warper(generation_config, device)
|
||||
assert len(warpers) == 2 # top_p and top_k
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
generation_model.config = MockConfig() # needed by the following method
|
||||
generation_model._prepare_special_tokens(generation_config, device=device)
|
||||
processors = generation_model._get_logits_processor(generation_config,
|
||||
None,
|
||||
None,
|
||||
None, [],
|
||||
device=device)
|
||||
assert len(processors) == 2 # top_p and top_k
|
||||
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
seq_lens: List[int] = []
|
||||
@@ -639,7 +651,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
|
||||
assert sample_probs is not None
|
||||
|
||||
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
|
||||
hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
|
||||
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
||||
torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
|
||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
||||
Reference in New Issue
Block a user