[Refactor] Simplify BOS/EOS token handling (#34435)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -469,8 +469,7 @@ def test_stop_via_update_from_output():
|
||||
|
||||
# Test case 4: Ignore EOS flag
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=1, max_tokens=10)
|
||||
requests[0].sampling_params.ignore_eos = True
|
||||
requests = create_requests(num_requests=1, max_tokens=10, ignore_eos=True)
|
||||
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
@@ -515,12 +514,12 @@ def test_check_stop_min_tokens():
|
||||
max_tokens=20,
|
||||
min_tokens=5,
|
||||
)
|
||||
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
|
||||
request = Request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[0, 1, 2],
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
)
|
||||
# Simulate having generated 3 output tokens (less than min_tokens=5)
|
||||
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
|
||||
@@ -551,12 +550,12 @@ def test_check_stop_min_tokens():
|
||||
max_tokens=20,
|
||||
min_tokens=0,
|
||||
)
|
||||
sampling_params_no_min.update_from_generation_config({}, EOS_TOKEN_ID)
|
||||
request_no_min = Request(
|
||||
request_id="1",
|
||||
prompt_token_ids=[0, 1, 2],
|
||||
sampling_params=sampling_params_no_min,
|
||||
pooling_params=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
)
|
||||
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
|
||||
|
||||
@@ -571,12 +570,12 @@ def test_check_stop_min_tokens():
|
||||
min_tokens=5,
|
||||
stop_token_ids=[42],
|
||||
)
|
||||
sampling_params_stop.update_from_generation_config({}, EOS_TOKEN_ID)
|
||||
request_stop = Request(
|
||||
request_id="2",
|
||||
prompt_token_ids=[0, 1, 2],
|
||||
sampling_params=sampling_params_stop,
|
||||
pooling_params=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
)
|
||||
# Only 3 output tokens, less than min_tokens=5, but has stop token
|
||||
request_stop.append_output_token_ids([10, 11, 42])
|
||||
@@ -1877,6 +1876,7 @@ def create_requests_with_priority(
|
||||
stop_token_ids=stop_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
|
||||
requests = []
|
||||
|
||||
if mm_hashes_list is not None:
|
||||
@@ -1938,7 +1938,6 @@ def create_requests_with_priority(
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
mm_features=mm_features if mm_features else None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
arrival_time=arrival_times[i],
|
||||
priority=priorities[i],
|
||||
block_hasher=block_hasher,
|
||||
@@ -2429,13 +2428,13 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
max_tokens=16,
|
||||
structured_outputs=structured_outputs_params,
|
||||
)
|
||||
sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
|
||||
request = Request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[0, 1],
|
||||
mm_features=None,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
)
|
||||
scheduler.add_request(request)
|
||||
output = scheduler.schedule()
|
||||
|
||||
Reference in New Issue
Block a user