[Core] Add optional flags to check for repetitive token patterns in engine output (#35451)

Signed-off-by: aykoppol <aykoppol+git@gmail.com>
This commit is contained in:
aykoppol
2026-03-02 20:23:25 -08:00
committed by GitHub
parent a0a5178ab4
commit 25e02647c2
7 changed files with 433 additions and 2 deletions

View File

@@ -107,6 +107,43 @@ class StructuredOutputsParams:
)
@dataclass
class RepetitionDetectionParams:
    """Settings for spotting repetitive N-gram patterns in output tokens."""

    max_pattern_size: int = 0
    """Largest N-gram pattern size considered when looking for repeated
    sequences. A value of 0 turns detection off. Has to be combined with
    min_count."""

    min_pattern_size: int = 0
    """Smallest N-gram pattern size considered when looking for repeated
    sequences. When left at 0, it defaults to 1. May not exceed
    max_pattern_size."""

    min_count: int = 0
    """How many times an N-gram pattern has to repeat before detection
    triggers; must be at least 2 (for instance, 3 flags a phrase that is
    repeated 3 times). Has to be combined with max_pattern_size."""

    def __post_init__(self):
        """Validate the configured pattern sizes and repeat count.

        Raises:
            ValueError: if either pattern size is negative, if
                min_pattern_size is larger than max_pattern_size, or if
                max_pattern_size is positive while min_count is below 2.
        """
        largest = self.max_pattern_size
        smallest = self.min_pattern_size

        # Both bounds must be non-negative and correctly ordered.
        bad_bounds = largest < 0 or smallest < 0 or smallest > largest
        if bad_bounds:
            raise ValueError(
                "max_pattern_size, min_pattern_size must be >=0, with "
                "min_pattern_size <= max_pattern_size. Set both to 0 to "
                "disable repetitive pattern detection."
            )

        # A repeat count only matters once a positive max size enables
        # detection; a single occurrence is not a repetition.
        detection_enabled = largest > 0
        if detection_enabled and self.min_count < 2:
            raise ValueError(
                "min_count must be >= 2 to detect repetitive patterns in "
                "engine output. If you do not wish to detect repetitive "
                "patterns, set max_pattern_size to 0."
            )
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@@ -246,6 +283,14 @@ class SamplingParams(
skip_reading_prefix_cache: bool | None = None
# Optional repetition-detection settings for this request; the default is
# None, i.e. no RepetitionDetectionParams object is supplied.
repetition_detection: RepetitionDetectionParams | None = None
"""Parameters for detecting repetitive N-gram patterns in output tokens.
If such repetition is detected, generation will be ended early. LLMs can
sometimes generate repetitive, unhelpful token patterns, stopping only
when they hit the maximum output length (e.g. 'abcdabcdabcd...' or
'\\emoji \\emoji \\emoji ...'). This feature can detect such behavior
and terminate early, saving time and tokens."""
@staticmethod
def from_optional(
n: int | None = 1,
@@ -275,6 +320,7 @@ class SamplingParams(
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
skip_clone: bool = False,
repetition_detection: RepetitionDetectionParams | None = None,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
@@ -314,6 +360,7 @@ class SamplingParams(
allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
skip_clone=skip_clone,
repetition_detection=repetition_detection,
)
def __post_init__(self) -> None: