[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
@@ -4,10 +4,12 @@ from typing import List, Optional
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
model: str = "facebook/opt-125m",
|
||||
@@ -38,6 +40,7 @@ def create_scheduler(
|
||||
return Scheduler(scheduler_config,
|
||||
model_config,
|
||||
cache_config,
|
||||
speculative_config=None,
|
||||
lora_config=None,
|
||||
log_stats=True)
|
||||
|
||||
@@ -46,8 +49,12 @@ def create_requests(
|
||||
num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[List[PlaceholderRange]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
):
|
||||
sampling_params = SamplingParams()
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
if mm_positions is not None:
|
||||
@@ -64,7 +71,7 @@ def create_requests(
|
||||
multi_modal_inputs=mm_inputs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
arrival_time=0,
|
||||
)
|
||||
requests.append(request)
|
||||
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[0] * len(requests),
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
|
||||
assert requests[2].request_id not in output.num_scheduled_tokens
|
||||
|
||||
|
||||
def test_stop_via_update_from_output():
|
||||
"""Test stopping behavior through update_from_output"""
|
||||
scheduler = create_scheduler()
|
||||
|
||||
# Test case 1: Stop on EOS token
|
||||
requests = create_requests(num_requests=2, max_tokens=10)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
scheduler.scheduled_req_ids.add(req.request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID],
|
||||
[10,
|
||||
11]], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped, second continues
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
|
||||
assert list(requests[1].output_token_ids) == [10, 11]
|
||||
|
||||
# Test case 2: Stop on custom stop token
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=2,
|
||||
max_tokens=10,
|
||||
stop_token_ids=[42, 43])
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
scheduler.scheduled_req_ids.add(req.request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped on custom token
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||
assert requests[0].stop_reason == 42
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [10, 42]
|
||||
assert list(requests[1].output_token_ids) == [13, 14]
|
||||
|
||||
# Test case 3: Stop on max tokens
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=2, max_tokens=2)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
scheduler.scheduled_req_ids.add(req.request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped due to length
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [10, 11
|
||||
] # Truncated to max_tokens
|
||||
assert list(requests[1].output_token_ids) == [13]
|
||||
|
||||
# Test case 4: Ignore EOS flag
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=1, max_tokens=10)
|
||||
requests[0].sampling_params.ignore_eos = True
|
||||
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
scheduler.scheduled_req_ids.add(requests[0].request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify request continues past EOS
|
||||
assert len(scheduler.running) == 1
|
||||
assert not requests[0].is_finished()
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
||||
|
||||
|
||||
def test_schedule_concurrent_batches():
|
||||
scheduler = create_scheduler(
|
||||
max_num_batched_tokens=1024,
|
||||
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[0],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[0],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user