[Core] Add optional flags to check for repetitive token patterns in engine output (#35451)
Signed-off-by: aykoppol <aykoppol+git@gmail.com>
This commit is contained in:
@@ -107,6 +107,43 @@ class StructuredOutputsParams:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class RepetitionDetectionParams:
    """Settings that control detection of repeated N-gram patterns in output.

    The values are validated in ``__post_init__``; an inconsistent
    combination raises ``ValueError`` at construction time.
    """

    max_pattern_size: int = 0
    """Largest N-gram pattern size considered when looking for repetition.
    A value of 0 disables detection. Must be used together with min_count."""

    min_pattern_size: int = 0
    """Smallest N-gram pattern size considered when looking for repetition.
    If set to 0, it defaults to 1.
    Must be <= max_pattern_size."""

    min_count: int = 0
    """How many times an N-gram pattern has to repeat before detection
    triggers. Must be >= 2 (e.g. 3 to catch a phrase repeated 3 times).
    Must be used together with max_pattern_size."""

    def __post_init__(self):
        """Validate the configured sizes and repeat count.

        Raises:
            ValueError: if either pattern size is negative, if
                min_pattern_size exceeds max_pattern_size, or if detection
                is enabled (max_pattern_size > 0) while min_count < 2.
        """
        upper = self.max_pattern_size
        lower = self.min_pattern_size

        # Both bounds must be non-negative and correctly ordered.
        has_negative_size = upper < 0 or lower < 0
        if has_negative_size or lower > upper:
            raise ValueError(
                "max_pattern_size, min_pattern_size must be >=0, "
                "with min_pattern_size <= max_pattern_size. "
                "Set both to 0 to disable repetitive pattern detection."
            )

        # min_count is only checked when detection is switched on.
        detection_enabled = upper > 0
        if detection_enabled and self.min_count < 2:
            raise ValueError(
                "min_count must be >= 2 to detect repetitive patterns "
                "in engine output. If you do not wish to detect repetitive "
                "patterns, set max_pattern_size to 0."
            )
|
||||
|
||||
|
||||
class RequestOutputKind(Enum):
|
||||
# Return entire output so far in every RequestOutput
|
||||
CUMULATIVE = 0
|
||||
@@ -246,6 +283,14 @@ class SamplingParams(
|
||||
|
||||
skip_reading_prefix_cache: bool | None = None
|
||||
|
||||
repetition_detection: RepetitionDetectionParams | None = None
|
||||
"""Parameters for detecting repetitive N-gram patterns in output tokens.
|
||||
If such repetition is detected, generation will be ended early. LLMs can
|
||||
sometimes generate repetitive, unhelpful token patterns, stopping only
|
||||
when they hit the maximum output length (e.g. 'abcdabcdabcd...' or
|
||||
'\\emoji \\emoji \\emoji ...'). This feature can detect such behavior
|
||||
and terminate early, saving time and tokens."""
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
n: int | None = 1,
|
||||
@@ -275,6 +320,7 @@ class SamplingParams(
|
||||
allowed_token_ids: list[int] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
skip_clone: bool = False,
|
||||
repetition_detection: RepetitionDetectionParams | None = None,
|
||||
) -> "SamplingParams":
|
||||
if logit_bias is not None:
|
||||
# Convert token_id to integer
|
||||
@@ -314,6 +360,7 @@ class SamplingParams(
|
||||
allowed_token_ids=allowed_token_ids,
|
||||
extra_args=extra_args,
|
||||
skip_clone=skip_clone,
|
||||
repetition_detection=repetition_detection,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user