# File: vllm/tests/v1/core/test_repetition_detection.py
# (291 lines, 10 KiB, Python)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.sampling_params import RepetitionDetectionParams, SamplingParams
from vllm.v1.core.sched.utils import check_sequence_repetition, check_stop
from vllm.v1.request import Request, RequestStatus
# Module-level pytest marker: every test collected from this file gets the
# `cpu_test` mark. NOTE(review): presumably used to select tests that can run
# without an accelerator -- confirm against the project's pytest configuration.
pytestmark = pytest.mark.cpu_test
# ============================================================================
# UNIT TESTS - check_sequence_repetition function
# ============================================================================
class TestCheckSequenceRepetition:
    """Direct unit tests of check_sequence_repetition on plain token lists."""

    @staticmethod
    def _flags_repetition(tokens, max_size, min_size, count):
        """Return the detector's verdict for ``tokens`` under explicit settings.

        Builds a RepetitionDetectionParams from the three values and forwards
        it, together with ``tokens``, to check_sequence_repetition.
        """
        settings = RepetitionDetectionParams(
            max_pattern_size=max_size,
            min_pattern_size=min_size,
            min_count=count,
        )
        return check_sequence_repetition(tokens, settings)

    def test_simple_repetition_detected(self):
        """[1, 2, 3] three times in a row is flagged when min_count is 3."""
        tokens = [1, 2, 3] * 3
        assert self._flags_repetition(tokens, max_size=3, min_size=2, count=3)

    def test_repetition_below_min_count(self):
        """Only two occurrences of the pattern are not enough for min_count=3."""
        tokens = [1, 2, 3] * 2
        assert not self._flags_repetition(tokens, max_size=3, min_size=2, count=3)

    def test_two_token_pattern(self):
        """A two-token pattern repeated four times is flagged with min_count=4."""
        tokens = [1, 2] * 4
        assert self._flags_repetition(tokens, max_size=5, min_size=2, count=4)

    def test_no_repetition_varied_sequence(self):
        """Nine distinct tokens are never reported as repetitive."""
        tokens = list(range(1, 10))
        assert not self._flags_repetition(tokens, max_size=5, min_size=2, count=2)

    def test_partial_repetition_not_detected(self):
        """A third occurrence broken by its last token (4, not 3) is not flagged."""
        tokens = [1, 2, 3, 1, 2, 3, 1, 2, 4]
        assert not self._flags_repetition(tokens, max_size=3, min_size=2, count=3)

    def test_empty_token_list(self):
        """An empty token list is not reported as repetitive."""
        assert not self._flags_repetition([], max_size=3, min_size=2, count=2)

    def test_detection_disabled_max_size_zero(self):
        """Default-constructed params must not flag a repetitive sequence.

        NOTE(review): relies on RepetitionDetectionParams() defaulting to a
        disabled state (presumably max_pattern_size == 0) -- confirm.
        """
        tokens = [1, 2] * 3
        default_settings = RepetitionDetectionParams()
        assert not check_sequence_repetition(tokens, default_settings)

    def test_invalid_min_count(self):
        """Default-constructed params must not flag [1, 2, 1, 2].

        NOTE(review): the test name refers to min_count < 2, but min_count is
        never set here; the params are the same defaults used by
        test_detection_disabled_max_size_zero. Whether the min_count path is
        exercised depends on those defaults -- confirm.
        """
        tokens = [1, 2] * 2
        default_settings = RepetitionDetectionParams()
        assert not check_sequence_repetition(tokens, default_settings)

    def test_repetition_at_end_of_sequence(self):
        """A pattern that only starts repeating after a varied prefix is flagged."""
        tokens = [1, 2, 3, 4] + [5, 6] * 3
        assert self._flags_repetition(tokens, max_size=3, min_size=2, count=3)

    def test_large_pattern_many_repetitions(self):
        """An eight-token pattern repeated five times is flagged."""
        tokens = [1, 2, 3, 4, 5, 6, 7, 8] * 5
        assert self._flags_repetition(tokens, max_size=10, min_size=2, count=3)
# ============================================================================
# INTEGRATION TESTS - check_stop with repetition detection
# ============================================================================
class TestRepetitionDetectionIntegration:
    """Checks of repetition detection as observed through check_stop."""

    @staticmethod
    def _build_request(sampling, prompt, generated):
        """Create a Request and append ``generated`` as its output tokens.

        The request always uses request_id "test" and no pooling params;
        ``sampling`` and ``prompt`` are passed through unchanged.
        """
        req = Request(
            request_id="test",
            prompt_token_ids=prompt,
            sampling_params=sampling,
            pooling_params=None,
        )
        req.append_output_token_ids(generated)
        return req

    def test_basic_repetition_stops_generation(self):
        """A repeated two-token output finishes the request as a repetition."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        req = self._build_request(
            SamplingParams(max_tokens=100, repetition_detection=detection),
            prompt=[1, 2, 3],
            generated=[10, 20] * 3,
        )
        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_REPETITION
        assert req.stop_reason == "repetition_detected"

    def test_detection_disabled_no_stop(self):
        """Without repetition_detection set, a repetitive output does not stop."""
        req = self._build_request(
            SamplingParams(max_tokens=100),
            prompt=[1, 2, 3],
            generated=[10, 20] * 3,
        )
        assert not check_stop(req, max_model_len=1024)

    def test_repetition_respects_min_tokens(self):
        """With min_tokens=10, six repetitive output tokens do not stop."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(
            min_tokens=10,
            max_tokens=100,
            repetition_detection=detection,
        )
        req = self._build_request(
            sampling,
            prompt=[1, 2, 3],
            generated=[10, 20] * 3,
        )
        assert not check_stop(req, max_model_len=1024)

    def test_no_repetition_continues_generation(self):
        """Six distinct output tokens do not trigger a stop."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        req = self._build_request(
            SamplingParams(max_tokens=100, repetition_detection=detection),
            prompt=[1, 2, 3],
            generated=[10, 20, 30, 40, 50, 60],
        )
        assert not check_stop(req, max_model_len=1024)

    def test_pattern_at_size_boundary(self):
        """A 3-token pattern is caught when min and max pattern size are both 3."""
        detection = RepetitionDetectionParams(
            max_pattern_size=3,
            min_pattern_size=3,
            min_count=2,
        )
        req = self._build_request(
            SamplingParams(max_tokens=100, repetition_detection=detection),
            prompt=[1, 2],
            generated=[10, 20, 30] * 2,
        )
        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_REPETITION

    def test_multiple_pattern_sizes_checked(self):
        """A 4-token pattern is caught when the allowed size range is 2..5."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        req = self._build_request(
            SamplingParams(max_tokens=100, repetition_detection=detection),
            prompt=[1],
            generated=[7, 8, 9, 10] * 3,
        )
        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_REPETITION

    def test_eos_takes_precedence_over_repetition(self):
        """An output ending in a configured stop token finishes as STOPPED.

        The stop token (999) is supplied through stop_token_ids, and the
        status must be FINISHED_STOPPED rather than FINISHED_REPETITION.
        NOTE(review): the output [10, 20, 10, 20, 999] does not itself end in
        a pattern repeated min_count times, so this may not exercise a real
        conflict between the two stop conditions -- confirm intent.
        """
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=3,
        )
        sampling = SamplingParams(
            max_tokens=100,
            stop_token_ids=[999],
            repetition_detection=detection,
        )
        req = self._build_request(
            sampling,
            prompt=[1, 2, 3],
            generated=[10, 20, 10, 20, 999],
        )
        assert check_stop(req, max_model_len=1024)
        assert req.status == RequestStatus.FINISHED_STOPPED

    def test_min_pattern_size_filters_small_patterns(self):
        """A 2-token pattern is ignored when min_pattern_size is 3."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=3,
            min_count=3,
        )
        req = self._build_request(
            SamplingParams(max_tokens=100, repetition_detection=detection),
            prompt=[1],
            generated=[10, 20] * 3,
        )
        assert not check_stop(req, max_model_len=1024)

    def test_high_repetition_threshold(self):
        """Three occurrences of a pattern do not stop when min_count is 5."""
        detection = RepetitionDetectionParams(
            max_pattern_size=5,
            min_pattern_size=2,
            min_count=5,
        )
        req = self._build_request(
            SamplingParams(max_tokens=100, repetition_detection=detection),
            prompt=[1],
            generated=[10, 20] * 3,
        )
        assert not check_stop(req, max_model_len=1024)