[BugFix] Fix async scheduling for pooling models (#31584)

Signed-off-by: njhill <nickhill123@gmail.com>

Author: Nick Hill
Committed by: GitHub
Date: 2025-12-31 14:48:51 -08:00
Commit: 6c2cfb62ff (parent: d8da76f3b7)

7 changed files with 132 additions and 93 deletions

View File

@@ -104,8 +104,11 @@ try:
             scheduler_output, intermediate_tensors
         )
         if isinstance(output, IntermediateTensors):
-            output = scheduler_output, grammar_output, output
-        elif not get_pp_group().is_last_rank:
+            return scheduler_output, grammar_output, output
+        if isinstance(output, AsyncModelRunnerOutput):
+            output = output.get_output()
+        if not get_pp_group().is_last_rank:
             # Case where there are no scheduled requests
             # but may still be finished requests.
             assert not output or not output.req_ids
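The hunk above makes the caller tolerant of the new async wrapper: an `AsyncModelRunnerOutput` has to be unwrapped via `get_output()` before the usual empty-batch checks run. A minimal sketch of that unwrap-before-use flow, using stand-in classes rather than vLLM's real types:

```python
# Illustrative stand-ins, not vLLM's actual classes: the async wrapper defers
# the device-to-host copy until get_output() is called by the worker loop.
from dataclasses import dataclass, field


@dataclass
class Output:
    req_ids: list[str] = field(default_factory=list)


class AsyncOutputWrapper:
    def __init__(self, output: Output):
        self._output = output

    def get_output(self) -> Output:
        # In the real code this blocks on a CUDA event before returning.
        return self._output


def handle(output):
    # Unwrap the async variant first, then apply the usual checks.
    if isinstance(output, AsyncOutputWrapper):
        output = output.get_output()
    assert not output or not output.req_ids
    return output


print(handle(AsyncOutputWrapper(Output())))  # -> Output(req_ids=[])
```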

View File

@@ -151,21 +151,23 @@ class ModelRunnerOutput:
     # num_generated_tokens is the number of tokens
     # generated in the current step. It can be different for
     # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    sampled_token_ids: list[list[int]] = field(default_factory=list)
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs]
-    logprobs: LogprobsLists | None
+    logprobs: LogprobsLists | None = None
     # req_id -> (token_ids, logprobs, ranks)
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len]
-    prompt_logprobs_dict: dict[str, LogprobsTensors | None]
+    prompt_logprobs_dict: dict[str, LogprobsTensors | None] = field(
+        default_factory=dict
+    )
     # [num_reqs, hidden_size]
-    pooler_output: list[torch.Tensor | None]
+    pooler_output: list[torch.Tensor | None] | None = None
     kv_connector_output: KVConnectorOutput | None = None
@@ -225,21 +227,8 @@ def make_empty_encoder_model_runner_output(
         req_ids=req_ids,
         req_id_to_index=req_id_to_index,
         sampled_token_ids=sampled_token_ids,
-        logprobs=None,
-        prompt_logprobs_dict={},
         pooler_output=pooler_output,
-        kv_connector_output=None,
-        ec_connector_output=None,
-        num_nans_in_logits=None,
     )
 
-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
-    req_ids=[],
-    req_id_to_index={},
-    sampled_token_ids=[],
-    logprobs=None,
-    prompt_logprobs_dict={},
-    pooler_output=[],
-    num_nans_in_logits=None,
-)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={})
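The new field defaults are what allow `EMPTY_MODEL_RUNNER_OUTPUT` to be built from just `req_ids` and `req_id_to_index`. A self-contained sketch of the pattern (a cut-down stand-in dataclass, not the real `ModelRunnerOutput`), showing why the mutable fields need `field(default_factory=...)`:

```python
# Mutable fields get field(default_factory=...) so each instance owns its own
# list/dict, and an "empty" output can be built from only the required fields.
from dataclasses import dataclass, field

import torch


@dataclass
class MiniRunnerOutput:
    req_ids: list[str]
    req_id_to_index: dict[str, int]
    sampled_token_ids: list[list[int]] = field(default_factory=list)
    pooler_output: list[torch.Tensor | None] | None = None


EMPTY = MiniRunnerOutput(req_ids=[], req_id_to_index={})

a = MiniRunnerOutput(req_ids=["r0"], req_id_to_index={"r0": 0})
a.sampled_token_ids.append([1, 2, 3])
assert EMPTY.sampled_token_ids == []  # not shared with `a`
```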

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
 
+import numpy as np
 import torch
 
 from vllm.pooling_params import PoolingParams
@@ -91,36 +92,27 @@ class PoolingMetadata:
     def build_pooling_cursor(
         self,
-        num_scheduled_tokens: list[int],
+        num_scheduled_tokens_np: np.ndarray,
         seq_lens_cpu: torch.Tensor,
         device: torch.device,
     ):
-        self.pooling_cursor = build_pooling_cursor(
-            num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
+        n_seq = len(num_scheduled_tokens_np)
+        prompt_lens = self.prompt_lens
+        assert len(prompt_lens) == n_seq
+        index = list(range(n_seq))
+        num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)
+        cumsum = torch.zeros(
+            n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
+        )
+        torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
+        cumsum = cumsum.to(device, non_blocking=True)
+        self.pooling_cursor = PoolingCursor(
+            index=index,
+            first_token_indices_gpu=cumsum[:n_seq],
+            last_token_indices_gpu=cumsum[1:] - 1,
+            prompt_lens_cpu=prompt_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
         )
-
-
-def build_pooling_cursor(
-    num_scheduled_tokens: list[int],
-    seq_lens_cpu: torch.Tensor,
-    prompt_lens: torch.Tensor,
-    device: torch.device,
-):
-    assert len(prompt_lens) == len(num_scheduled_tokens)
-    n_seq = len(num_scheduled_tokens)
-    index = list(range(n_seq))
-    num_scheduled_tokens_cpu = torch.tensor(num_scheduled_tokens, device="cpu")
-    cumsum = torch.zeros(
-        n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
-    )
-    torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
-    cumsum = cumsum.to(device, non_blocking=True)
-    return PoolingCursor(
-        index=index,
-        first_token_indices_gpu=cumsum[:n_seq],
-        last_token_indices_gpu=cumsum[1:] - 1,
-        prompt_lens_cpu=prompt_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
-    )
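For reference, a small worked example of the cumsum trick `build_pooling_cursor` relies on (plain CPU tensors, no pinned memory): the exclusive prefix sum of per-request token counts gives each request's first flattened token index, and the inclusive prefix sum minus one gives the last.

```python
import numpy as np
import torch

num_scheduled_tokens_np = np.array([3, 1, 4], dtype=np.int64)
num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)

n_seq = len(num_scheduled_tokens_np)
cumsum = torch.zeros(n_seq + 1, dtype=torch.int64)
torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])

first_token_indices = cumsum[:n_seq]  # tensor([0, 3, 4])
last_token_indices = cumsum[1:] - 1   # tensor([2, 3, 7])
print(first_token_indices.tolist(), last_token_indices.tolist())
```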

View File

@@ -968,11 +968,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Only for compatibility with the existing model runner and scheduler.
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
             sampled_token_ids=None,  # type: ignore
-            logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
-            pooler_output=[],
-            kv_connector_output=None,
-            num_nans_in_logits=None,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
         )
         async_output = AsyncOutput(
             model_runner_output=model_runner_output,

View File

@@ -254,6 +254,50 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         return output
 
 
+class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        raw_pooler_output: PoolerOutput,
+        finished_mask: list[bool],
+        async_output_copy_stream: torch.cuda.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._finished_mask = finished_mask
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self.async_copy_ready_event = torch.Event()
+
+        # Keep a reference to the device tensors to avoid them being
+        # deallocated until we finish copying it to the host.
+        self._raw_pooler_output = raw_pooler_output
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._raw_pooler_output_cpu = json_map_leaves(
+                lambda x: None if x is None else x.to("cpu", non_blocking=True),
+                self._raw_pooler_output,
+            )
+            self.async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+        """
+        self.async_copy_ready_event.synchronize()
+
+        # Release the device tensors once the copy has completed.
+        del self._raw_pooler_output
+
+        self._model_runner_output.pooler_output = [
+            out if include else None
+            for out, include in zip(self._raw_pooler_output_cpu, self._finished_mask)
+        ]
+        return self._model_runner_output
+
+
 class ExecuteModelState(NamedTuple):
     """Ephemeral cached state transferred between execute_model() and
     sample_tokens(), after execute_model() returns None."""
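The new class defers the device-to-host copy of the pooled tensors: the copy is launched on a side stream, an event is recorded, and only `get_output()` blocks on that event. A standalone sketch of the same stream/event pattern (requires a CUDA device; plain lists stand in for `json_map_leaves` over a `PoolerOutput`):

```python
import torch

assert torch.cuda.is_available()

copy_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

device_tensors = [torch.randn(4, 8, device="cuda"), None, torch.randn(2, 8, device="cuda")]

default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
    # Make the copy stream wait for all work already queued on the default stream.
    copy_stream.wait_stream(default_stream)
    host_tensors = [
        None if t is None else t.to("cpu", non_blocking=True) for t in device_tensors
    ]
    ready.record()

# ... other host-side work can proceed here without waiting for the copy ...

ready.synchronize()  # block only when the result is actually consumed
print([None if t is None else tuple(t.shape) for t in host_tensors])
```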
@@ -2476,17 +2520,19 @@ class GPUModelRunner(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-    ) -> ModelRunnerOutput:
-        assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), (
+        kv_connector_output: KVConnectorOutput | None,
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs == len(self.input_batch.pooling_params), (
             "Either all or none of the requests in a batch must be pooling request"
         )
 
         hidden_states = hidden_states[:num_scheduled_tokens]
-        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
+        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
 
         pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(
-            num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
+            num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device
         )
 
         model = cast(VllmModelForPooling, self.model)
@@ -2494,27 +2540,41 @@ class GPUModelRunner(
             hidden_states=hidden_states,
             pooling_metadata=pooling_metadata,
         )
 
+        finished_mask = [
+            seq_len == prompt_len
+            for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
+        ]
+
+        model_runner_output = ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids.copy(),
+            req_id_to_index=self.input_batch.req_id_to_index.copy(),
+            kv_connector_output=kv_connector_output,
+        )
+
+        if raw_pooler_output is None or not any(finished_mask):
+            model_runner_output.pooler_output = [None] * num_reqs
+            return model_runner_output
+
+        if self.use_async_scheduling:
+            return AsyncGPUPoolingModelRunnerOutput(
+                model_runner_output=model_runner_output,
+                raw_pooler_output=raw_pooler_output,
+                finished_mask=finished_mask,
+                async_output_copy_stream=self.async_output_copy_stream,
+            )
+
         raw_pooler_output = json_map_leaves(
-            lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
+            lambda x: None if x is None else x.to("cpu", non_blocking=True),
             raw_pooler_output,
         )
         self._sync_device()
 
-        pooler_output: list[torch.Tensor | None] = []
-        for raw_output, seq_len, prompt_len in zip(
-            raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens
-        ):
-            output = raw_output if seq_len == prompt_len else None
-            pooler_output.append(output)
-
-        return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=[],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=pooler_output,
-        )
+        model_runner_output.pooler_output = [
+            out if include else None
+            for out, include in zip(raw_pooler_output, finished_mask)
+        ]
+        return model_runner_output
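`finished_mask` encodes the chunked-prefill case: a request only has a usable pooled output in the step where its whole prompt has been processed, so unfinished requests get `None` in `pooler_output`. A small illustration with made-up lengths:

```python
import torch

seq_lens_cpu = torch.tensor([16, 7, 32])   # tokens processed so far per request
prompt_lens = torch.tensor([16, 12, 32])   # full prompt lengths
raw_pooler_output = [torch.randn(8), torch.randn(8), torch.randn(8)]

finished_mask = [
    seq_len == prompt_len
    for seq_len, prompt_len in zip(seq_lens_cpu, prompt_lens)
]
pooler_output = [
    out if include else None
    for out, include in zip(raw_pooler_output, finished_mask)
]
# Only the first and third requests have completed their prompts.
assert [o is not None for o in pooler_output] == [True, False, True]
```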
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when # Pad tokens to multiple of tensor_parallel_size when
@@ -3036,7 +3096,7 @@ class GPUModelRunner(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: IntermediateTensors | None = None,
-    ) -> ModelRunnerOutput | IntermediateTensors | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError(
                 "State error: sample_tokens() must be called "
@@ -3244,11 +3304,12 @@ class GPUModelRunner(
         if self.is_pooling_model:
             # Return the pooling output.
-            output = self._pool(
-                hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
+            return self._pool(
+                hidden_states,
+                num_scheduled_tokens,
+                num_scheduled_tokens_np,
+                kv_connector_output,
             )
-            output.kv_connector_output = kv_connector_output
-            return output
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
@@ -3437,7 +3498,6 @@ class GPUModelRunner(
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
-            pooler_output=[],
             kv_connector_output=kv_connector_output,
             ec_connector_output=ec_connector_output
             if self.supports_mm_inputs
@@ -4508,17 +4568,14 @@ class GPUModelRunner(
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
         min_tokens_per_req = num_tokens // num_reqs
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req)
+        num_scheduled_tokens_np[-1] += num_tokens % num_reqs
+        assert np.sum(num_scheduled_tokens_np) == num_tokens
+        assert len(num_scheduled_tokens_np) == num_reqs
 
         req_num_tokens = num_tokens // num_reqs
 
-        dummy_prompt_lens = torch.tensor(
-            num_scheduled_tokens_list,
-            device="cpu",
-        )
+        dummy_prompt_lens = torch.from_numpy(num_scheduled_tokens_np)
         dummy_token_ids = torch.zeros(
             (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device
         )
@@ -4537,7 +4594,7 @@ class GPUModelRunner(
         )
 
         dummy_metadata.build_pooling_cursor(
-            num_scheduled_tokens_list,
+            num_scheduled_tokens_np,
             seq_lens_cpu=dummy_prompt_lens,
             device=hidden_states.device,
         )
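A quick check of the dummy-batch split used here for profiling: tokens are spread evenly with `np.full` and the remainder lands on the last request, so the counts always sum back to `num_tokens` (example values are made up):

```python
import numpy as np

num_tokens, max_num_reqs = 1030, 8
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs

num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req)
num_scheduled_tokens_np[-1] += num_tokens % num_reqs

assert np.sum(num_scheduled_tokens_np) == num_tokens
print(num_scheduled_tokens_np)  # [128 128 128 128 128 128 128 134]
```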

View File

@@ -575,7 +575,7 @@ class Worker(WorkerBase):
     @torch.inference_mode()
     def execute_model(
         self, scheduler_output: "SchedulerOutput"
-    ) -> ModelRunnerOutput | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -624,7 +624,9 @@ class Worker(WorkerBase):
             output = self.model_runner.execute_model(
                 scheduler_output, intermediate_tensors
             )
-            if isinstance(output, ModelRunnerOutput | NoneType):
+            if isinstance(
+                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
+            ):
                 return output
 
             assert isinstance(output, IntermediateTensors)
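The widened `isinstance` check leans on the fact that, since Python 3.10, `isinstance` accepts PEP 604 unions directly, and `types.NoneType` covers the no-output case. A tiny demonstration with stand-in classes:

```python
from types import NoneType


class ModelOutput: ...      # stand-in names for illustration only
class AsyncModelOutput: ...


for value in (ModelOutput(), AsyncModelOutput(), None, "something else"):
    matched = isinstance(value, ModelOutput | AsyncModelOutput | NoneType)
    print(type(value).__name__, matched)
# ModelOutput True / AsyncModelOutput True / NoneType True / str False
```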

View File

@@ -124,7 +124,7 @@ class WorkerBase:
     def execute_model(
         self, scheduler_output: SchedulerOutput
-    ) -> ModelRunnerOutput | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
         """If this method returns None, sample_tokens should be called immediately after
         to obtain the ModelRunnerOutput.
@@ -362,7 +362,7 @@ class WorkerWrapperBase:
         scheduler_output: SchedulerOutput,
         *args,
         **kwargs,
-    ) -> ModelRunnerOutput | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
         self._apply_mm_cache(scheduler_output)
         assert self.worker is not None
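For context, the `execute_model()` / `sample_tokens()` contract mentioned in the docstring above, sketched with a hypothetical worker stub (the method names mirror the interface; the bodies and return values are made up):

```python
class StubWorker:
    def execute_model(self, scheduler_output):
        # Returning None defers sampling to sample_tokens().
        return None

    def sample_tokens(self):
        return {"req_ids": [], "sampled_token_ids": []}


def step(worker, scheduler_output):
    output = worker.execute_model(scheduler_output)
    if output is None:
        # Per the docstring: call sample_tokens() immediately after a None result.
        output = worker.sample_tokens()
    if hasattr(output, "get_output"):  # AsyncModelRunnerOutput-style wrapper
        output = output.get_output()
    return output


print(step(StubWorker(), scheduler_output=object()))
```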