[BugFix] Fix test breakages from transformers 4.45 upgrade (#8829)

This commit is contained in:
Nick Hill
2024-09-27 00:46:43 +01:00
committed by GitHub
parent 71d21c73ab
commit 4b377d6feb
13 changed files with 62 additions and 49 deletions

View File

@@ -1,5 +1,6 @@
import itertools
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
@@ -596,8 +597,19 @@ def test_sampler_top_k_top_p(seed: int, device: str):
generation_config = GenerationConfig(top_k=top_k,
top_p=top_p,
do_sample=True)
warpers = generation_model._get_logits_warper(generation_config, device)
assert len(warpers) == 2 # top_p and top_k
@dataclass
class MockConfig:
is_encoder_decoder: bool = False
generation_model.config = MockConfig() # needed by the following method
generation_model._prepare_special_tokens(generation_config, device=device)
processors = generation_model._get_logits_processor(generation_config,
None,
None,
None, [],
device=device)
assert len(processors) == 2 # top_p and top_k
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
@@ -639,7 +651,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert sample_probs is not None
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))