[Core] Add optional flags to check for repetitive token patterns in engine output (#35451)

Signed-off-by: aykoppol <aykoppol+git@gmail.com>
This commit is contained in:
aykoppol
2026-03-02 20:23:25 -08:00
committed by GitHub
parent a0a5178ab4
commit 25e02647c2
7 changed files with 433 additions and 2 deletions

View File

@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.sampling_params import RepetitionDetectionParams, SamplingParams
from vllm.v1.core.sched.utils import check_sequence_repetition, check_stop
from vllm.v1.request import Request, RequestStatus
pytestmark = pytest.mark.cpu_test
# ============================================================================
# UNIT TESTS - check_sequence_repetition function
# ============================================================================
class TestCheckSequenceRepetition:
    """Unit tests for the check_sequence_repetition function"""

    def test_simple_repetition_detected(self):
        """Test detection of simple repetitive patterns"""
        token_ids = [1, 2, 3, 1, 2, 3, 1, 2, 3]
        params = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=2,
            min_count=3,
        )
        assert check_sequence_repetition(token_ids, params)

    def test_repetition_below_min_count(self):
        """Test that pattern below min_count is not detected"""
        token_ids = [1, 2, 3, 1, 2, 3]
        params = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=2,
            min_count=3,
        )
        assert not check_sequence_repetition(token_ids, params)

    def test_two_token_pattern(self):
        """Test detection of 2-token patterns"""
        token_ids = [1, 2, 1, 2, 1, 2, 1, 2]
        params = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=4,
        )
        assert check_sequence_repetition(token_ids, params)

    def test_no_repetition_varied_sequence(self):
        """Test that non-repetitive sequences are not flagged"""
        token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        params = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=2,
        )
        assert not check_sequence_repetition(token_ids, params)

    def test_partial_repetition_not_detected(self):
        """Test that incomplete repetitions are not detected"""
        token_ids = [1, 2, 3, 1, 2, 3, 1, 2, 4]
        params = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=2,
            min_count=3,
        )
        assert not check_sequence_repetition(token_ids, params)

    def test_empty_token_list(self):
        """Test with empty token list"""
        params = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=2,
            min_count=2,
        )
        assert not check_sequence_repetition([], params)

    def test_detection_disabled_max_size_zero(self):
        """Test that zero max_pattern_size disables detection"""
        token_ids = [1, 2, 1, 2, 1, 2]
        # Default params leave max_pattern_size at 0, i.e. detection disabled.
        params = RepetitionDetectionParams()
        assert not check_sequence_repetition(token_ids, params)

    def test_invalid_min_count(self):
        """Test that min_count < 2 is rejected when detection is enabled"""
        # With max_pattern_size > 0 the params object validates min_count in
        # __post_init__, so an invalid value must fail at construction time
        # instead of silently producing a detector that never fires.
        with pytest.raises(ValueError):
            RepetitionDetectionParams(
                max_pattern_size=2,
                min_pattern_size=2,
                min_count=1,
            )
        # Leaving min_count at its default (0) while enabling detection is
        # rejected for the same reason.
        with pytest.raises(ValueError):
            RepetitionDetectionParams(max_pattern_size=2)

    def test_repetition_at_end_of_sequence(self):
        """Test detection when repetition occurs at the end"""
        token_ids = [1, 2, 3, 4, 5, 6, 5, 6, 5, 6]
        params = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=2,
            min_count=3,
        )
        assert check_sequence_repetition(token_ids, params)

    def test_large_pattern_many_repetitions(self):
        """Test large pattern repeated many times"""
        token_ids = [1, 2, 3, 4, 5, 6, 7, 8] * 5
        params = RepetitionDetectionParams(
            max_pattern_size=10,
            min_pattern_size=2,
            min_count=3,
        )
        assert check_sequence_repetition(token_ids, params)
# ============================================================================
# INTEGRATION TESTS - check_stop with repetition detection
# ============================================================================
class TestRepetitionDetectionIntegration:
    """Integration tests: repetition detection as seen through check_stop."""

    @staticmethod
    def _request_with_output(sampling_params, prompt_token_ids, output_token_ids):
        """Create a request and append the given tokens as its output."""
        req = Request(
            request_id="test",
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
        )
        req.append_output_token_ids(output_token_ids)
        return req

    def test_basic_repetition_stops_generation(self):
        """A repeated 2-token pattern finishes the request as a repetition."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(max_tokens=100, repetition_detection=detection)
        req = self._request_with_output(
            sampling, [1, 2, 3], [10, 20, 10, 20, 10, 20]
        )

        stopped = check_stop(req, max_model_len=1024)

        assert stopped
        assert req.status == RequestStatus.FINISHED_REPETITION
        assert req.stop_reason == "repetition_detected"

    def test_detection_disabled_no_stop(self):
        """Without repetition_detection the same output keeps generating."""
        sampling = SamplingParams(max_tokens=100)
        req = self._request_with_output(
            sampling, [1, 2, 3], [10, 20, 10, 20, 10, 20]
        )

        assert not check_stop(req, max_model_len=1024)

    def test_repetition_respects_min_tokens(self):
        """With min_tokens above the output length, the request is not stopped."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(
            min_tokens=10,
            max_tokens=100,
            repetition_detection=detection,
        )
        req = self._request_with_output(
            sampling, [1, 2, 3], [10, 20, 10, 20, 10, 20]
        )

        assert not check_stop(req, max_model_len=1024)

    def test_no_repetition_continues_generation(self):
        """Distinct output tokens never trigger the repetition stop."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(max_tokens=100, repetition_detection=detection)
        req = self._request_with_output(
            sampling, [1, 2, 3], [10, 20, 30, 40, 50, 60]
        )

        assert not check_stop(req, max_model_len=1024)

    def test_pattern_at_size_boundary(self):
        """A pattern whose size equals both min and max size is detected."""
        detection = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=3,
            min_count=2,
        )
        sampling = SamplingParams(max_tokens=100, repetition_detection=detection)
        req = self._request_with_output(
            sampling, [1, 2], [10, 20, 30, 10, 20, 30]
        )

        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_REPETITION

    def test_multiple_pattern_sizes_checked(self):
        """A 4-token pattern is found when sizes 2 through 5 are scanned."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(max_tokens=100, repetition_detection=detection)
        req = self._request_with_output(
            sampling, [1], [7, 8, 9, 10, 7, 8, 9, 10, 7, 8, 9, 10]
        )

        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_REPETITION

    def test_eos_takes_precedence_over_repetition(self):
        """Output ending in a stop token finishes as FINISHED_STOPPED."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(
            max_tokens=100,
            stop_token_ids=[999],
            repetition_detection=detection,
        )
        req = self._request_with_output(
            sampling, [1, 2, 3], [10, 20, 10, 20, 999]
        )

        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_STOPPED

    def test_min_pattern_size_filters_small_patterns(self):
        """A 2-token pattern is ignored when min_pattern_size is 3."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=3,
            min_count=3,
        )
        sampling = SamplingParams(max_tokens=100, repetition_detection=detection)
        req = self._request_with_output(
            sampling, [1], [10, 20, 10, 20, 10, 20]
        )

        assert not check_stop(req, max_model_len=1024)

    def test_high_repetition_threshold(self):
        """Three repetitions are not enough when min_count is 5."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=5,
        )
        sampling = SamplingParams(max_tokens=100, repetition_detection=detection)
        req = self._request_with_output(
            sampling, [1], [10, 20, 10, 20, 10, 20]
        )

        assert not check_stop(req, max_model_len=1024)

View File

@@ -38,6 +38,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
# Optional per-request repetition detection settings; None leaves it unset.
# NOTE: the backslashes in the example below are escaped ("\\emoji") because
# "\e" is not a valid escape sequence in a regular string literal; the
# resulting runtime text is unchanged.
repetition_detection: RepetitionDetectionParams | None = Field(
    default=None,
    description="Parameters for detecting repetitive N-gram patterns "
    "in output tokens. If such repetition is detected, generation will "
    "be ended early. LLMs can sometimes generate repetitive, unhelpful "
    "token patterns, stopping only when they hit the maximum output length "
    "(e.g. 'abcdabcdabcd...' or '\\emoji \\emoji \\emoji ...'). This feature "
    "can detect such behavior and terminate early, saving time and tokens.",
)
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
@@ -499,6 +510,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)
@model_validator(mode="before")

View File

@@ -26,6 +26,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel):
),
)
# Optional per-request repetition detection settings; None leaves it unset.
# NOTE: the backslashes in the example below are escaped ("\\emoji") because
# "\e" is not a valid escape sequence in a regular string literal; the
# resulting runtime text is unchanged.
repetition_detection: RepetitionDetectionParams | None = Field(
    default=None,
    description="Parameters for detecting repetitive N-gram patterns "
    "in output tokens. If such repetition is detected, generation will "
    "be ended early. LLMs can sometimes generate repetitive, unhelpful "
    "token patterns, stopping only when they hit the maximum output length "
    "(e.g. 'abcdabcdabcd...' or '\\emoji \\emoji \\emoji ...'). This feature "
    "can detect such behavior and terminate early, saving time and tokens.",
)
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
@@ -310,6 +321,7 @@ class CompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)
@model_validator(mode="before")

View File

@@ -107,6 +107,43 @@ class StructuredOutputsParams:
)
@dataclass
class RepetitionDetectionParams:
    """Settings that control detection of repeated N-gram patterns in output
    tokens."""

    max_pattern_size: int = 0
    """Largest N-gram pattern size that is checked for repetition.
    A value of 0 disables detection. Must be combined with min_count."""

    min_pattern_size: int = 0
    """Smallest N-gram pattern size that is checked for repetition.
    A value of 0 is treated as 1.
    Must not exceed max_pattern_size."""

    min_count: int = 0
    """How many times an N-gram pattern has to repeat before it counts as a
    detection. Must be >= 2 (e.g. 3 to catch a phrase repeated 3 times).
    Must be combined with max_pattern_size."""

    def __post_init__(self):
        """Validate the size bounds and, when enabled, the repeat count.

        Raises:
            ValueError: if either size is negative, if min_pattern_size
                exceeds max_pattern_size, or if detection is enabled
                (max_pattern_size > 0) while min_count is below 2.
        """
        # Both sizes non-negative and ordered; the chained comparison also
        # guarantees max_pattern_size >= 0.
        sizes_ok = 0 <= self.min_pattern_size <= self.max_pattern_size
        if not sizes_ok:
            raise ValueError(
                "max_pattern_size, min_pattern_size must be >=0, "
                "with min_pattern_size <= max_pattern_size. "
                "Set both to 0 to disable repetitive pattern detection."
            )

        # min_count only matters once detection is switched on.
        detection_enabled = self.max_pattern_size > 0
        if detection_enabled and self.min_count < 2:
            raise ValueError(
                "min_count must be >= 2 to detect repetitive patterns "
                "in engine output. If you do not wish to detect repetitive "
                "patterns, set max_pattern_size to 0."
            )
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@@ -246,6 +283,14 @@ class SamplingParams(
skip_reading_prefix_cache: bool | None = None
repetition_detection: RepetitionDetectionParams | None = None
"""Parameters for detecting repetitive N-gram patterns in output tokens.
If such repetition is detected, generation will be ended early. LLMs can
sometimes generate repetitive, unhelpful token patterns, stopping only
when they hit the maximum output length (e.g. 'abcdabcdabcd...' or
'\\emoji \\emoji \\emoji ...'). This feature can detect such behavior
and terminate early, saving time and tokens."""
@staticmethod
def from_optional(
n: int | None = 1,
@@ -275,6 +320,7 @@ class SamplingParams(
allowed_token_ids: list[int] | None = None,
extra_args: dict[str, Any] | None = None,
skip_clone: bool = False,
repetition_detection: RepetitionDetectionParams | None = None,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
@@ -314,6 +360,7 @@ class SamplingParams(
allowed_token_ids=allowed_token_ids,
extra_args=extra_args,
skip_clone=skip_clone,
repetition_detection=repetition_detection,
)
def __post_init__(self) -> None:

View File

@@ -1,10 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Sequence
from vllm.sampling_params import RepetitionDetectionParams
from vllm.v1.request import Request, RequestStatus
def _has_repeating_pattern(
token_ids: Sequence[int],
pattern_len: int,
repetition_min_count: int,
) -> bool:
"""Check if the tail of token_ids contains a repeating pattern.
Compares the last pattern_len tokens against the preceding
(repetition_min_count - 1) repetitions of the same length.
"""
for n in range(1, pattern_len + 1):
target_token = token_ids[-n]
for m in range(1, repetition_min_count):
if token_ids[-(pattern_len * m + n)] != target_token:
return False
return True
def check_sequence_repetition(
    token_ids: Sequence[int],
    params: RepetitionDetectionParams,
) -> bool:
    """Report whether token_ids ends in a repeated pattern.

    Every pattern length from the effective minimum size up to
    params.max_pattern_size is tried in increasing order.

    Args:
        token_ids: Sequence of token IDs to inspect.
        params: Repetition detection parameters.

    Returns:
        True if some pattern length in range repeats at least
        params.min_count times at the end of token_ids, False otherwise
        (including when detection is disabled or misconfigured).
    """
    # A non-positive minimum size falls back to 1.
    smallest = params.min_pattern_size if params.min_pattern_size > 0 else 1
    largest = params.max_pattern_size
    required_repeats = params.min_count

    detection_off = largest <= 0 or required_repeats < 2 or smallest > largest
    if detection_off:
        return False

    available = len(token_ids)
    for size in range(smallest, largest + 1):
        # Longer patterns need even more tokens, so nothing further can match.
        if size * required_repeats > available:
            break
        if _has_repeating_pattern(token_ids, size, required_repeats):
            return True
    return False
def remove_all(lst: list, items_to_remove: set) -> list:
"""Remove all items from a list that are in the items_to_remove set.
@@ -61,4 +115,16 @@ def check_stop(request: Request, max_model_len: int) -> bool:
):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
repetition_detection = sampling_params.repetition_detection
if repetition_detection is not None and (
check_sequence_repetition(
request.output_token_ids,
repetition_detection,
)
):
request.status = RequestStatus.FINISHED_REPETITION
request.stop_reason = "repetition_detected"
return True
return False

View File

@@ -27,7 +27,7 @@ PauseMode = Literal["abort", "wait", "keep"]
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error", "repetition")
EEP_NOTIFICATION_CALL_ID = -1
@@ -41,7 +41,7 @@ class EEPNotificationType(enum.Enum):
class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, abort, or error.
Reason a request finished - stop, length, abort, error, or repetition.
Int rather than Str for more compact serialization.
@@ -50,6 +50,7 @@ class FinishReason(enum.IntEnum):
abort - aborted by client
error - retryable request-level internal error (e.g., KV load failure).
Invariant: always converted to 500 Internal Server Error.
repetition - repetitive token pattern detected (hallucination)
"""
@@ -57,6 +58,7 @@ class FinishReason(enum.IntEnum):
LENGTH = 1
ABORT = 2
ERROR = 3
REPETITION = 4
def __str__(self):
return FINISH_REASON_STRINGS[self.value]

View File

@@ -320,6 +320,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
FINISHED_ERROR = enum.auto()
FINISHED_REPETITION = enum.auto()
def __str__(self) -> str:
return self.name
@@ -344,4 +345,5 @@ _FINISHED_REASON_MAP = {
RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
RequestStatus.FINISHED_REPETITION: FinishReason.REPETITION,
}