[Bugfix] Validate custom logits processor xargs for online serving (#27560)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -131,10 +131,34 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of overriding the wrapper class `__init__()` in order to utilize
|
||||
info about the device type"""
|
||||
|
||||
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    is_pin_memory: bool,
):
    """Forward the three constructor arguments, unchanged, to the parent class."""
    super().__init__(vllm_config, device, is_pin_memory)
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the n-gram arguments carried in `params.extra_args`.

    Reads `ngram_size`, `window_size` (default 100) and
    `whitelist_token_ids` from `params.extra_args`. When `ngram_size` is
    absent the processor will not be used, so nothing is validated.

    Args:
        params: Sampling parameters; only its `extra_args` attribute is read.

    Returns:
        None.

    Raises:
        ValueError: If `ngram_size` or `window_size` is not a strictly
            positive integer, or if `whitelist_token_ids` is given but is
            not iterable.
    """
    # Normalise a missing *or empty* `extra_args` to an empty mapping.
    # Chaining with `and` (`extra_args and extra_args.get(...)`) evaluates to
    # the empty dict itself when `extra_args == {}`, which is not None and
    # would make a request that never asked for this processor fail the
    # `ngram_size` check below.
    extra_args = params.extra_args or {}
    ngram_size = extra_args.get("ngram_size")
    window_size = extra_args.get("window_size", 100)
    whitelist_token_ids = extra_args.get("whitelist_token_ids", None)

    # if ngram_size is not provided, skip validation because the processor
    # will not be used.
    if ngram_size is None:
        return None

    if not isinstance(ngram_size, int) or ngram_size <= 0:
        raise ValueError(
            f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
        )
    if not isinstance(window_size, int) or window_size <= 0:
        raise ValueError(
            "`window_size` has to be a strictly positive integer, "
            f"got {window_size}."
        )
    if whitelist_token_ids is not None and not isinstance(
        whitelist_token_ids, Iterable
    ):
        raise ValueError(
            "`whitelist_token_ids` has to be a sequence of integers, "
            f"got {whitelist_token_ids}."
        )
    return None
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
    """Always report this processor as argmax-invariant."""
    return True
|
||||
@@ -150,26 +174,8 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
)
|
||||
if ngram_size is None:
|
||||
return None
|
||||
if not isinstance(ngram_size, int) or ngram_size <= 0:
|
||||
raise ValueError(
|
||||
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
|
||||
)
|
||||
if not isinstance(window_size, int) or window_size <= 0:
|
||||
raise ValueError(
|
||||
"`window_size` has to be a strictly positive integer, "
|
||||
f"got {window_size}."
|
||||
)
|
||||
if whitelist_token_ids is not None and not isinstance(
|
||||
whitelist_token_ids, Iterable
|
||||
):
|
||||
raise ValueError(
|
||||
"`whitelist_token_ids` has to be a set of integers, "
|
||||
f"got {whitelist_token_ids}."
|
||||
)
|
||||
else:
|
||||
whitelist_token_ids = (
|
||||
set(whitelist_token_ids) if whitelist_token_ids else None
|
||||
)
|
||||
|
||||
whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None
|
||||
return NoRepeatNGramLogitsProcessor(
|
||||
ngram_size=ngram_size,
|
||||
window_size=window_size,
|
||||
|
||||
Reference in New Issue
Block a user