[Refactor] Simplify BOS/EOS token handling (#34435)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-13 10:18:24 +08:00
committed by GitHub
parent 04ea31baab
commit ea5ff3a1f6
29 changed files with 123 additions and 123 deletions

View File

@@ -469,8 +469,7 @@ def test_stop_via_update_from_output():
# Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2)
-requests = create_requests(num_requests=1, max_tokens=10)
-requests[0].sampling_params.ignore_eos = True
+requests = create_requests(num_requests=1, max_tokens=10, ignore_eos=True)
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
@@ -515,12 +514,12 @@ def test_check_stop_min_tokens():
max_tokens=20,
min_tokens=5,
)
+sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
request = Request(
request_id="0",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params,
pooling_params=None,
-eos_token_id=EOS_TOKEN_ID,
)
# Simulate having generated 3 output tokens (less than min_tokens=5)
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
@@ -551,12 +550,12 @@ def test_check_stop_min_tokens():
max_tokens=20,
min_tokens=0,
)
+sampling_params_no_min.update_from_generation_config({}, EOS_TOKEN_ID)
request_no_min = Request(
request_id="1",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_no_min,
pooling_params=None,
-eos_token_id=EOS_TOKEN_ID,
)
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
@@ -571,12 +570,12 @@ def test_check_stop_min_tokens():
min_tokens=5,
stop_token_ids=[42],
)
+sampling_params_stop.update_from_generation_config({}, EOS_TOKEN_ID)
request_stop = Request(
request_id="2",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_stop,
pooling_params=None,
-eos_token_id=EOS_TOKEN_ID,
)
# Only 3 output tokens, less than min_tokens=5, but has stop token
request_stop.append_output_token_ids([10, 11, 42])
@@ -1877,6 +1876,7 @@ def create_requests_with_priority(
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs,
)
+sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
requests = []
if mm_hashes_list is not None:
@@ -1938,7 +1938,6 @@ def create_requests_with_priority(
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
-eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
block_hasher=block_hasher,
@@ -2429,13 +2428,13 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
max_tokens=16,
structured_outputs=structured_outputs_params,
)
+sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
-eos_token_id=EOS_TOKEN_ID,
)
scheduler.add_request(request)
output = scheduler.schedule()