[BugFix] Fix test breakages from transformers 4.45 upgrade (#8829)

This commit is contained in:
Nick Hill
2024-09-27 00:46:43 +01:00
committed by GitHub
parent 71d21c73ab
commit 4b377d6feb
13 changed files with 62 additions and 49 deletions

View File

@@ -3,7 +3,6 @@
Run `pytest tests/models/test_granite.py`.
"""
import pytest
import transformers
from ...utils import check_logprobs_close
@@ -12,9 +11,6 @@ MODELS = [
]
# GraniteForCausalLM will be in transformers >= 4.45
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="granite model test requires transformers >= 4.45")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])