Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
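For context on the diff below: yapf wraps long calls by aligning continuation arguments under the opening parenthesis and leaves quote style alone, while ruff's formatter (Black-compatible) uses a four-space hanging indent, trailing commas on exploded argument lists, and double-quoted strings. A minimal illustration of the two styles (the helper name here is made up for this sketch, not taken from the diff):

def some_function(first_argument, second_argument=False):
    """Stand-in helper, used only to show the two wrapping styles."""
    return first_argument, second_argument


# yapf: continuation arguments aligned under the opening parenthesis
result = some_function('value',
                       second_argument=True)

# ruff format: hanging indent and double quotes (a call this short would
# normally stay on one line; it is exploded here only for comparison)
result = some_function(
    "value", second_argument=True
)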
tests/samplers/test_no_bad_words.py
@@ -5,6 +5,7 @@
 
 Run `pytest tests/samplers/test_no_bad_words.py`.
 """
+
 from typing import Optional
 
 import pytest
@@ -16,7 +17,7 @@ from vllm import LLM, SamplingParams
 @pytest.fixture(autouse=True)
 def v1(monkeypatch):
     """Only run on vLLM v1."""
-    monkeypatch.setenv('VLLM_USE_V1', '1')
+    monkeypatch.setenv("VLLM_USE_V1", "1")
 
 
 def _generate(
@@ -49,25 +50,24 @@ class TestOneTokenBadWord:
     TARGET_TOKEN = "you"
 
     def setup_method(self, method):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
-                                                       add_prefix_space=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.MODEL, add_prefix_space=True
+        )
 
         self.num_prompt_tokens = len(self._encode(self.PROMPT))
-        self.target_token_id = self._encode(self.TARGET_TOKEN,
-                                            add_special_tokens=False)[0]
+        self.target_token_id = self._encode(
+            self.TARGET_TOKEN, add_special_tokens=False
+        )[0]
 
     def test_one_token_bad_word(self, vllm_runner):
         with vllm_runner(self.MODEL) as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id
 
-            output_token_ids = self._generate(llm,
-                                              bad_words=[self.TARGET_TOKEN])
+            output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN])
             assert self.target_token_id not in output_token_ids
 
-    def _generate(self,
-                  llm: LLM,
-                  bad_words: Optional[list[str]] = None) -> list[int]:
+    def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]:
         return _generate(
             llm=llm,
             prompt=self.PROMPT,
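(These tests exercise vLLM's `bad_words` sampling option, which bans the listed words, as token sequences, from the generated output. A hedged standalone usage sketch; the model name is an assumption for illustration, since `self.MODEL` is defined outside the hunks shown here:)

from vllm import LLM, SamplingParams

# Assumed model; the test classes read self.MODEL, which this diff does not show.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=8, bad_words=["you"])
outputs = llm.generate(["Hi! How are"], params)
print(outputs[0].outputs[0].text)  # "you" should not appear in the output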
@@ -75,11 +75,8 @@ class TestOneTokenBadWord:
             bad_words=bad_words,
         )
 
-    def _encode(self,
-                prompt: str,
-                add_special_tokens: bool = True) -> list[int]:
-        return self.tokenizer(prompt,
-                              add_special_tokens=add_special_tokens).input_ids
+    def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
+        return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
 
 
 class TestTwoTokenBadWord:
@@ -92,72 +89,80 @@ class TestTwoTokenBadWord:
     NEIGHBOUR_TOKEN2 = "older"
 
     def setup_method(self, method):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
-                                                       add_prefix_space=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.MODEL, add_prefix_space=True
+        )
 
         self.num_prompt_tokens = len(self._encode(self.PROMPT))
-        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
-                                             add_special_tokens=False)[0]
-        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
-                                             add_special_tokens=False)[0]
-        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
-                                                add_special_tokens=False)[0]
+        self.target_token_id1 = self._encode(
+            self.TARGET_TOKEN1, add_special_tokens=False
+        )[0]
+        self.target_token_id2 = self._encode(
+            self.TARGET_TOKEN2, add_special_tokens=False
+        )[0]
+        self.neighbour_token_id2 = self._encode(
+            self.NEIGHBOUR_TOKEN2, add_special_tokens=False
+        )[0]
 
     def test_two_token_bad_word(self, vllm_runner):
         with vllm_runner(self.MODEL, dtype="half") as llm:
             output_token_ids = self._generate(llm)
             assert output_token_ids[:2] == [
-                self.target_token_id1, self.target_token_id2
+                self.target_token_id1,
+                self.target_token_id2,
             ]
 
-            output_token_ids = self._generate(llm,
-                                              bad_words=[self.TARGET_TOKEN1])
+            output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN1])
             assert self.target_token_id1 not in output_token_ids
 
-            output_token_ids = self._generate(llm,
-                                              bad_words=[self.TARGET_TOKEN2])
+            output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN2])
             assert output_token_ids[0] == self.target_token_id1
             assert self.target_token_id2 not in output_token_ids
 
             output_token_ids = self._generate(
-                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
+                llm, bad_words=[f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}"]
+            )
             assert output_token_ids[0] == self.target_token_id1
             assert output_token_ids[:2] != [
-                self.target_token_id1, self.target_token_id2
+                self.target_token_id1,
+                self.target_token_id2,
             ]
             assert not self._contains(
-                output_token_ids,
-                [self.target_token_id1, self.target_token_id2])
+                output_token_ids, [self.target_token_id1, self.target_token_id2]
+            )
             # Model dependent behaviour
             assert output_token_ids[:2] == [
-                self.target_token_id1, self.neighbour_token_id2
+                self.target_token_id1,
+                self.neighbour_token_id2,
             ]
 
             output_token_ids = self._generate(
                 llm,
                 bad_words=[
-                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
-                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
-                ])
+                    f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}",
+                    f"{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}",
+                ],
+            )
             assert output_token_ids[0] == self.target_token_id1
             assert output_token_ids[:2] != [
-                self.target_token_id1, self.target_token_id2
+                self.target_token_id1,
+                self.target_token_id2,
             ]
             assert not self._contains(
-                output_token_ids,
-                [self.target_token_id1, self.target_token_id2])
+                output_token_ids, [self.target_token_id1, self.target_token_id2]
+            )
             assert output_token_ids[:2] != [
-                self.target_token_id1, self.neighbour_token_id2
+                self.target_token_id1,
+                self.neighbour_token_id2,
             ]
             assert not self._contains(
-                output_token_ids,
-                [self.target_token_id1, self.neighbour_token_id2])
-            assert ((self.target_token_id2 in output_token_ids)
-                    or (self.neighbour_token_id2 in output_token_ids))
+                output_token_ids, [self.target_token_id1, self.neighbour_token_id2]
+            )
+            assert (self.target_token_id2 in output_token_ids) or (
+                self.neighbour_token_id2 in output_token_ids
+            )
 
-    def _generate(self,
-                  llm: LLM,
-                  bad_words: Optional[list[str]] = None) -> list[int]:
+    def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]:
         return _generate(
             llm=llm,
             prompt=self.PROMPT,
@@ -187,8 +192,5 @@ class TestTwoTokenBadWord:
 
         return False
 
-    def _encode(self,
-                prompt: str,
-                add_special_tokens: bool = True) -> list[int]:
-        return self.tokenizer(prompt,
-                              add_special_tokens=add_special_tokens).input_ids
+    def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
+        return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
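The assertions above rely on a `self._contains(...)` helper whose body falls outside the diff context; only its trailing `return False` is visible in the last hunk. A rough sketch of the contiguous-subsequence check those assertions imply — an assumption about the surrounding file, not code from this commit:

    def _contains(self, sequence: list[int], subsequence: list[int]) -> bool:
        # Assumed behaviour: true iff `subsequence` occurs contiguously
        # inside `sequence` (sketch only; not taken from the diff).
        n = len(subsequence)
        for start in range(len(sequence) - n + 1):
            if sequence[start:start + n] == subsequence:
                return True

        return False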