[v1] Support allowed_token_ids in v1 Sampler (#13210)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
@@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
|||||||
output_token_ids=[],
|
output_token_ids=[],
|
||||||
min_tokens={},
|
min_tokens={},
|
||||||
logit_bias=[None] * batch_size,
|
logit_bias=[None] * batch_size,
|
||||||
|
allowed_token_ids_mask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,26 @@ def _create_logit_bias(
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _create_allowed_token_ids(
|
||||||
|
batch_size: int,
|
||||||
|
vocab_size: int,
|
||||||
|
num_allowed_token_ids: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
mask: Optional[torch.Tensor] = None
|
||||||
|
for i in range(batch_size):
|
||||||
|
if i % 2 == 1:
|
||||||
|
continue
|
||||||
|
if mask is None:
|
||||||
|
mask = torch.zeros((batch_size, vocab_size),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device)
|
||||||
|
start = min(i, vocab_size - 1)
|
||||||
|
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
||||||
|
mask[i, start:end] = True
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def _create_default_sampling_metadata(
|
def _create_default_sampling_metadata(
|
||||||
num_output_tokens: int,
|
num_output_tokens: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
|
|||||||
no_penalties=True,
|
no_penalties=True,
|
||||||
min_tokens={},
|
min_tokens={},
|
||||||
logit_bias=[None] * batch_size,
|
logit_bias=[None] * batch_size,
|
||||||
|
allowed_token_ids_mask=None,
|
||||||
)
|
)
|
||||||
return fake_sampling_metadata
|
return fake_sampling_metadata
|
||||||
|
|
||||||
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
|||||||
sampling_metadata.frequency_penalties = _create_penalty_tensor(
|
sampling_metadata.frequency_penalties = _create_penalty_tensor(
|
||||||
batch_size, frequency_penalty, torch.device(device))
|
batch_size, frequency_penalty, torch.device(device))
|
||||||
output_token_ids, sorted_token_ids_in_output = \
|
output_token_ids, sorted_token_ids_in_output = \
|
||||||
_create_weighted_output_token_list(batch_size, VOCAB_SIZE)
|
_create_weighted_output_token_list(
|
||||||
|
batch_size,
|
||||||
|
VOCAB_SIZE,
|
||||||
|
)
|
||||||
sampling_metadata.output_token_ids = output_token_ids
|
sampling_metadata.output_token_ids = output_token_ids
|
||||||
sampling_metadata.no_penalties = False
|
sampling_metadata.no_penalties = False
|
||||||
sampler = Sampler()
|
sampler = Sampler()
|
||||||
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
|||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||||
penalized_token_id = logits[batch_idx].argmin().item()
|
penalized_token_id = logits[batch_idx].argmin().item()
|
||||||
distinct_sorted_token_ids_in_output = \
|
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
|
||||||
sorted_token_ids_in_output[batch_idx]
|
batch_idx]
|
||||||
most_frequent_token_id = distinct_sorted_token_ids_in_output[
|
most_frequent_token_id = distinct_sorted_token_ids_in_output[
|
||||||
len(distinct_sorted_token_ids_in_output) - 1]
|
len(distinct_sorted_token_ids_in_output) - 1]
|
||||||
if frequency_penalty > 0:
|
if frequency_penalty > 0:
|
||||||
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
|||||||
# non-penalized token ID is not present in the output, while the
|
# non-penalized token ID is not present in the output, while the
|
||||||
# most penalized token is the one that occurs most frequently in
|
# most penalized token is the one that occurs most frequently in
|
||||||
# the output.
|
# the output.
|
||||||
assert non_penalized_token_id \
|
assert (non_penalized_token_id
|
||||||
not in distinct_sorted_token_ids_in_output
|
not in distinct_sorted_token_ids_in_output)
|
||||||
assert penalized_token_id == most_frequent_token_id
|
assert penalized_token_id == most_frequent_token_id
|
||||||
elif frequency_penalty < 0:
|
elif frequency_penalty < 0:
|
||||||
# If `frequency_penalty` is set to < 0, it indicates
|
# If `frequency_penalty` is set to < 0, it indicates
|
||||||
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
|||||||
# in the output, while the penalized token ID is one that has not
|
# in the output, while the penalized token ID is one that has not
|
||||||
# yet appeared.
|
# yet appeared.
|
||||||
assert non_penalized_token_id == most_frequent_token_id
|
assert non_penalized_token_id == most_frequent_token_id
|
||||||
assert penalized_token_id \
|
assert penalized_token_id not in distinct_sorted_token_ids_in_output
|
||||||
not in distinct_sorted_token_ids_in_output
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
|||||||
# If `repetition_penalty` > 1.0, verify that the non-penalized
|
# If `repetition_penalty` > 1.0, verify that the non-penalized
|
||||||
# token ID has not been seen before, while the penalized token ID
|
# token ID has not been seen before, while the penalized token ID
|
||||||
# exists either in the prompt or the output.
|
# exists either in the prompt or the output.
|
||||||
assert (non_penalized_token_id not in prompt_tokens and \
|
assert (non_penalized_token_id not in prompt_tokens
|
||||||
non_penalized_token_id not in output_tokens)
|
and non_penalized_token_id not in output_tokens)
|
||||||
assert (penalized_token_id in prompt_tokens or \
|
assert (penalized_token_id in prompt_tokens
|
||||||
penalized_token_id in output_tokens)
|
or penalized_token_id in output_tokens)
|
||||||
elif repetition_penalty < 1.0:
|
elif repetition_penalty < 1.0:
|
||||||
# If `repetition_penalty` < 1.0, verify that the penalized
|
# If `repetition_penalty` < 1.0, verify that the penalized
|
||||||
# token ID has not been seen before, while the non-penalized
|
# token ID has not been seen before, while the non-penalized
|
||||||
# token ID exists either in the prompt or the output.
|
# token ID exists either in the prompt or the output.
|
||||||
assert (penalized_token_id not in prompt_tokens and \
|
assert (penalized_token_id not in prompt_tokens
|
||||||
penalized_token_id not in output_tokens)
|
and penalized_token_id not in output_tokens)
|
||||||
assert (non_penalized_token_id in prompt_tokens or \
|
assert (non_penalized_token_id in prompt_tokens
|
||||||
non_penalized_token_id in output_tokens)
|
or non_penalized_token_id in output_tokens)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
|
|||||||
1e-2)
|
1e-2)
|
||||||
else:
|
else:
|
||||||
assert logits_for_req[token_id] == pytest.approx(1e-2)
|
assert logits_for_req[token_id] == pytest.approx(1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
                                   num_allowed_token_ids: int):
    """
    Test to verify that `Sampler.apply_allowed_token_ids` fills exactly the
    logits selected by `allowed_token_ids_mask` with -inf: positions that are
    True in the mask become -inf, and every other logit (including all logits
    of requests whose mask row is empty) is left different from -inf.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    mask = _create_allowed_token_ids(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        num_allowed_token_ids=num_allowed_token_ids,
        device=torch.device(device),
    )
    sampling_metadata.allowed_token_ids_mask = mask
    sampler = Sampler()
    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        if batch_idx % 2 == 1:
            # Odd rows of the mask are all False, so nothing may be masked.
            assert torch.all(logits_for_req != -float("inf"))
            continue
        # Same [start, end) window that `_create_allowed_token_ids` marks as
        # True for this row; it does not depend on `token_id`.
        start = min(batch_idx, VOCAB_SIZE - 1)
        end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            if start <= token_id < end:
                assert logits_for_req[token_id] == -float(
                    "inf"), f"{batch_idx}, {token_id}"
            else:
                assert logits_for_req[token_id] != -float("inf")
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
|
|||||||
temperature = [0.0 for _ in range(num_reqs)]
|
temperature = [0.0 for _ in range(num_reqs)]
|
||||||
min_tokens = {}
|
min_tokens = {}
|
||||||
logit_bias = [None] * num_reqs
|
logit_bias = [None] * num_reqs
|
||||||
|
allowed_token_ids_mask = torch.zeros(num_reqs,
|
||||||
|
VOCAB_SIZE,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device)
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
if req.req_id not in req_ids_retained:
|
if req.req_id not in req_ids_retained:
|
||||||
continue
|
continue
|
||||||
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
|
|||||||
req.sampling_params.min_tokens,
|
req.sampling_params.min_tokens,
|
||||||
req.sampling_params.all_stop_token_ids)
|
req.sampling_params.all_stop_token_ids)
|
||||||
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
|
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
|
||||||
|
if req.sampling_params.allowed_token_ids:
|
||||||
|
allowed_token_ids_mask[index_in_input_batch][
|
||||||
|
req.sampling_params.allowed_token_ids] = True
|
||||||
|
|
||||||
return SamplingMetadata(
|
return SamplingMetadata(
|
||||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||||
device=device),
|
device=device),
|
||||||
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
|
|||||||
and all(x == 0 for x in frequency_penalties)
|
and all(x == 0 for x in frequency_penalties)
|
||||||
and all(x == 1 for x in repetition_penalties)),
|
and all(x == 1 for x in repetition_penalties)),
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
|
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -242,3 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
|||||||
assert expected_sampling_metadata.no_penalties == \
|
assert expected_sampling_metadata.no_penalties == \
|
||||||
sampling_metadata.no_penalties
|
sampling_metadata.no_penalties
|
||||||
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
|
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
|
||||||
|
if sampling_metadata.allowed_token_ids_mask:
|
||||||
|
assert torch.allclose(
|
||||||
|
expected_sampling_metadata.allowed_token_ids_mask,
|
||||||
|
sampling_metadata.allowed_token_ids_mask)
|
||||||
|
|||||||
@@ -83,6 +83,19 @@ class Processor:
|
|||||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||||
"not enabled!")
|
"not enabled!")
|
||||||
|
|
||||||
|
def _validate_allowed_token_ids(
    self,
    params: Union[SamplingParams, PoolingParams],
) -> None:
    """
    Check that every id in `params.allowed_token_ids` is inside the vocab.

    Params that are not `SamplingParams`, or that have `allowed_token_ids`
    set to None, are accepted without any check.

    Raises:
        ValueError: if any id is outside `[0, model_config.vocab_size)`.
    """
    if not isinstance(params, SamplingParams):
        return
    token_ids = params.allowed_token_ids
    if token_ids is None:
        return
    for tid in token_ids:
        if not 0 <= tid < self.model_config.vocab_size:
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id")
|
||||||
|
|
||||||
def process_inputs(
|
def process_inputs(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@@ -100,6 +113,7 @@ class Processor:
|
|||||||
|
|
||||||
self._validate_logprobs(params)
|
self._validate_logprobs(params)
|
||||||
self._validate_lora(lora_request)
|
self._validate_lora(lora_request)
|
||||||
|
self._validate_allowed_token_ids(params)
|
||||||
|
|
||||||
if arrival_time is None:
|
if arrival_time is None:
|
||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
|
|||||||
@@ -37,3 +37,7 @@ class SamplingMetadata:
|
|||||||
min_tokens: Dict[int, Tuple[int, Set[int]]]
|
min_tokens: Dict[int, Tuple[int, Set[int]]]
|
||||||
|
|
||||||
logit_bias: List[Optional[Dict[int, float]]]
|
logit_bias: List[Optional[Dict[int, float]]]
|
||||||
|
|
||||||
|
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
|
||||||
|
# vocab size). Entries that are True are filled with -inf by the sampler.
|
||||||
|
allowed_token_ids_mask: Optional[torch.Tensor]
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
# Use float32 for the logits.
|
# Use float32 for the logits.
|
||||||
logits = logits.to(torch.float32)
|
logits = logits.to(torch.float32)
|
||||||
|
# Apply allowed token ids.
|
||||||
|
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
|
||||||
# Apply logits bias.
|
# Apply logits bias.
|
||||||
logits = self.apply_logits_bias(logits, sampling_metadata)
|
logits = self.apply_logits_bias(logits, sampling_metadata)
|
||||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||||
@@ -184,11 +186,13 @@ class Sampler(nn.Module):
|
|||||||
if not sampling_metadata.no_penalties:
|
if not sampling_metadata.no_penalties:
|
||||||
assert sampling_metadata.prompt_token_ids is not None
|
assert sampling_metadata.prompt_token_ids is not None
|
||||||
logits = apply_all_penalties(
|
logits = apply_all_penalties(
|
||||||
logits, sampling_metadata.prompt_token_ids,
|
logits,
|
||||||
|
sampling_metadata.prompt_token_ids,
|
||||||
sampling_metadata.presence_penalties,
|
sampling_metadata.presence_penalties,
|
||||||
sampling_metadata.frequency_penalties,
|
sampling_metadata.frequency_penalties,
|
||||||
sampling_metadata.repetition_penalties,
|
sampling_metadata.repetition_penalties,
|
||||||
sampling_metadata.output_token_ids)
|
sampling_metadata.output_token_ids,
|
||||||
|
)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def apply_min_p(
|
def apply_min_p(
|
||||||
@@ -226,3 +230,13 @@ class Sampler(nn.Module):
|
|||||||
for token_id, bias in logit_bias.items():
|
for token_id, bias in logit_bias.items():
|
||||||
logits[i, token_id] += bias
|
logits[i, token_id] += bias
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def apply_allowed_token_ids(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """
    Fill the logits selected by `allowed_token_ids_mask` with -inf.

    The fill happens in place on `logits`, at every position where the mask
    is True, and the same tensor object is returned. When the metadata
    carries no mask (None), `logits` is returned untouched.
    """
    mask = sampling_metadata.allowed_token_ids_mask
    if mask is None:
        return logits
    logits.masked_fill_(mask, float("-inf"))
    return logits
|
||||||
|
|||||||
@@ -192,6 +192,9 @@ class InputBatch:
|
|||||||
|
|
||||||
self.logit_bias: List[Optional[Dict[int,
|
self.logit_bias: List[Optional[Dict[int,
|
||||||
float]]] = [None] * max_num_reqs
|
float]]] = [None] * max_num_reqs
|
||||||
|
self.has_allowed_token_ids: Set[str] = set()
|
||||||
|
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
self.req_output_token_ids: List[Optional[List[int]]] = []
|
self.req_output_token_ids: List[Optional[List[int]]] = []
|
||||||
|
|
||||||
@@ -287,6 +290,22 @@ class InputBatch:
|
|||||||
if sampling_params.logit_bias is not None:
|
if sampling_params.logit_bias is not None:
|
||||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||||
|
|
||||||
|
if sampling_params.allowed_token_ids:
|
||||||
|
self.has_allowed_token_ids.add(req_id)
|
||||||
|
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||||
|
# Lazy allocation for this tensor, which can be large.
|
||||||
|
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
|
||||||
|
self.vocab_size,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device)
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||||
|
self.max_num_reqs,
|
||||||
|
self.vocab_size,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device="cpu")
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||||
|
sampling_params.allowed_token_ids] = True
|
||||||
|
|
||||||
# Add request lora ID
|
# Add request lora ID
|
||||||
if request.lora_request:
|
if request.lora_request:
|
||||||
lora_id = request.lora_request.lora_int_id
|
lora_id = request.lora_request.lora_int_id
|
||||||
@@ -332,6 +351,9 @@ class InputBatch:
|
|||||||
self.request_lora_mapping[req_index] = 0
|
self.request_lora_mapping[req_index] = 0
|
||||||
|
|
||||||
self.logit_bias[req_index] = None
|
self.logit_bias[req_index] = None
|
||||||
|
self.has_allowed_token_ids.discard(req_id)
|
||||||
|
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||||
return req_index
|
return req_index
|
||||||
|
|
||||||
def condense(self, empty_req_indices: List[int]) -> None:
|
def condense(self, empty_req_indices: List[int]) -> None:
|
||||||
@@ -400,6 +422,11 @@ class InputBatch:
|
|||||||
|
|
||||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||||
|
|
||||||
|
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor[
|
||||||
|
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||||
|
last_req_index]
|
||||||
|
|
||||||
# Decrement last_req_index since it is now empty.
|
# Decrement last_req_index since it is now empty.
|
||||||
last_req_index -= 1
|
last_req_index -= 1
|
||||||
|
|
||||||
@@ -442,6 +469,13 @@ class InputBatch:
|
|||||||
else:
|
else:
|
||||||
prompt_token_ids = None
|
prompt_token_ids = None
|
||||||
|
|
||||||
|
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||||
|
if not self.no_allowed_token_ids:
|
||||||
|
assert self.allowed_token_ids_mask is not None
|
||||||
|
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
||||||
|
self.allowed_token_ids_mask, num_reqs)
|
||||||
|
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
||||||
|
|
||||||
return SamplingMetadata(
|
return SamplingMetadata(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
all_greedy=self.all_greedy,
|
all_greedy=self.all_greedy,
|
||||||
@@ -460,6 +494,7 @@ class InputBatch:
|
|||||||
min_tokens=self.min_tokens,
|
min_tokens=self.min_tokens,
|
||||||
no_penalties=self.no_penalties,
|
no_penalties=self.no_penalties,
|
||||||
logit_bias=self.logit_bias[:num_reqs],
|
logit_bias=self.logit_bias[:num_reqs],
|
||||||
|
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sampling_metadata(
|
def get_sampling_metadata(
|
||||||
@@ -550,3 +585,7 @@ class InputBatch:
|
|||||||
@property
|
@property
|
||||||
def no_prompt_logprob(self) -> bool:
|
def no_prompt_logprob(self) -> bool:
|
||||||
return not self.num_prompt_logprobs
|
return not self.num_prompt_logprobs
|
||||||
|
|
||||||
|
@property
def no_allowed_token_ids(self) -> bool:
    """True when `has_allowed_token_ids` is empty."""
    return not self.has_allowed_token_ids
|
||||||
|
|||||||
Reference in New Issue
Block a user