[ROCm][CI] Fix test_token_classification.py::test_bert_models (#31993)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
 from transformers import AutoModelForTokenClassification
 
 from tests.models.utils import softmax
+from vllm.platforms import current_platform
 
 
 @pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
|
||||
@@ -21,8 +22,17 @@ def test_bert_models(
|
||||
     with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.token_classify(example_prompts)
 
+    # Use eager attention on ROCm to avoid HF Transformers flash attention
+    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
+    hf_model_kwargs = {}
+    if current_platform.is_rocm():
+        hf_model_kwargs["attn_implementation"] = "eager"
+
     with hf_runner(
-        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
+        model,
+        dtype=dtype,
+        auto_cls=AutoModelForTokenClassification,
+        model_kwargs=hf_model_kwargs,
     ) as hf_model:
         tokenizer = hf_model.tokenizer
         hf_outputs = []
|
||||
@@ -36,7 +46,7 @@ def test_bert_models(
|
||||
     for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
         hf_output = hf_output.detach().clone().cpu().float()
         vllm_output = vllm_output.detach().clone().cpu().float()
-        assert torch.allclose(hf_output, vllm_output, 1e-2)
+        torch.testing.assert_close(hf_output, vllm_output, atol=1.2e-2, rtol=1e-3)
 
 
 @pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
|
||||
@@ -49,8 +59,6 @@ def test_modernbert_models(
|
||||
     model: str,
     dtype: str,
 ) -> None:
-    from vllm.platforms import current_platform
-
     with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.token_classify(example_prompts)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user