[Optimization] Cache sampled token ids in model runner (#20291)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Author: Woosuk Kwon (committed by GitHub)
Date: 2025-07-01 11:01:31 -07:00
Parent: 02cabff207
Commit: 7f280d69c9
5 changed files with 91 additions and 45 deletions

View File

@@ -88,6 +88,8 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: list[bool]
+    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
+    # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
     new_block_ids: list[tuple[list[int], ...]]
     num_computed_tokens: list[int]
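The NOTE added above is the contract behind this optimization: outside pipeline parallelism, new_token_ids travels empty, because the worker that sampled the tokens is the same process that consumes them on the next step. A minimal sketch of how a consumer might honor that contract; the dataclass mirrors the fields in the diff, while req_ids, resolve_new_tokens, and cached_sampled_ids are illustrative names, not vLLM's actual API:

from dataclasses import dataclass, field


@dataclass
class CachedRequestData:
    # Fields mirror the diff above; req_ids is assumed for illustration.
    req_ids: list[str] = field(default_factory=list)
    resumed_from_preemption: list[bool] = field(default_factory=list)
    # Empty unless pipeline parallelism is enabled.
    new_token_ids: list[list[int]] = field(default_factory=list)
    num_computed_tokens: list[int] = field(default_factory=list)


def resolve_new_tokens(
        data: CachedRequestData,
        cached_sampled_ids: dict[str, list[int]]) -> list[list[int]]:
    if data.new_token_ids:
        # PP case: the scheduler shipped the sampled ids explicitly.
        return data.new_token_ids
    # Non-PP case: this worker sampled the tokens itself on the previous
    # step and cached them, so nothing needs to cross the RPC boundary.
    return [cached_sampled_ids[req_id] for req_id in data.req_ids]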

View File

@@ -55,6 +55,7 @@ class Scheduler(SchedulerInterface):
         self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
         self.kv_events_config = vllm_config.kv_events_config
+        self.parallel_config = vllm_config.parallel_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
@@ -87,7 +88,7 @@ class Scheduler(SchedulerInterface):
         self.kv_event_publisher = EventPublisherFactory.create(
             self.kv_events_config,
-            vllm_config.parallel_config.data_parallel_rank,
+            self.parallel_config.data_parallel_rank,
         )

         num_gpu_blocks = self.cache_config.num_gpu_blocks
@@ -159,6 +160,7 @@ class Scheduler(SchedulerInterface):
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
         )
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1

     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -214,7 +216,7 @@ class Scheduler(SchedulerInterface):
                 # This is necessary when using spec decoding.
                 num_new_tokens = min(
                     num_new_tokens,
-                    self.max_model_len - request.num_computed_tokens)
+                    self.max_model_len - 1 - request.num_computed_tokens)

                 # Schedule encoder inputs.
                 encoder_inputs_to_schedule = None
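The new "- 1" reserves room for the token that sampling appends after the forward pass: without it, a request whose speculative tokens are all accepted can end one token past max_model_len. A small worked example with invented numbers:

max_model_len = 4096
num_computed_tokens = 4090
proposed = 8  # draft tokens suggested by spec decoding

old_limit = min(proposed, max_model_len - num_computed_tokens)      # 6
new_limit = min(proposed, max_model_len - 1 - num_computed_tokens)  # 5

# Old clamp: positions run through 4096, then sampling appends one more
# token, ending at 4097 > max_model_len. The new clamp leaves one slot.
assert num_computed_tokens + new_limit + 1 <= max_model_len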
@@ -624,9 +626,15 @@ class Scheduler(SchedulerInterface):
             req_ids.append(req_id)
             num_tokens = (num_scheduled_tokens[req_id] -
                           len(spec_decode_tokens.get(req_id, ())))
-            token_ids = req.all_token_ids[req.num_computed_tokens:
-                                          req.num_computed_tokens + num_tokens]
-            new_token_ids.append(token_ids)
+            if self.use_pp:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the
+                # first-stage worker and the last-stage worker. Otherwise,
+                # we don't need to send the sampled tokens back because the
+                # model runner will cache them.
+                token_ids = req.all_token_ids[req.num_computed_tokens:
+                                              req.num_computed_tokens +
+                                              num_tokens]
+                new_token_ids.append(token_ids)
             new_block_ids.append(req_to_new_block_ids[req_id])
             num_computed_tokens.append(req.num_computed_tokens)
         # Because resumed_reqs is usually empty, it is more efficient to do
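The runner-side half of the change, in the remaining changed files not shown in this view, is what the commit title refers to: each worker remembers the token ids it sampled, so the scheduler only round-trips them when pipeline parallelism separates the sampling stage from the first stage. A minimal sketch of that caching idea; SampledTokenCache, record, and take are illustrative names, not the actual vLLM implementation:

class SampledTokenCache:
    """Per-worker cache of the token ids from the last sampling step."""

    def __init__(self) -> None:
        self._sampled: dict[str, list[int]] = {}

    def record(self, req_id: str, token_ids: list[int]) -> None:
        # Called right after sampling on the last PP stage (or the only
        # stage, when PP is off).
        self._sampled[req_id] = token_ids

    def take(self, req_id: str, scheduler_token_ids: list[int]) -> list[int]:
        # Prefer ids sent by the scheduler (PP case, new_token_ids
        # non-empty); otherwise fall back to the locally cached ids.
        if scheduler_token_ids:
            return scheduler_token_ids
        return self._sampled.pop(req_id, [])

The payoff is that in the common single-stage case the per-step scheduler-to-worker message no longer carries any token ids, which is what saves the serialization cost this commit targets.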