[Model Runner V2] Support pooling models (#35120)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-27 18:03:01 -08:00
committed by GitHub
parent 405f28d38d
commit 86ac7bcf84
6 changed files with 209 additions and 14 deletions

View File

@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
return self.model_runner_output return self.model_runner_output
class AsyncPoolingOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
pooler_output: torch.Tensor,
is_valid: torch.Tensor | None,
main_stream: torch.cuda.Stream,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
self.model_runner_output = model_runner_output
self.pooler_output = pooler_output
self.is_valid = is_valid
self.copy_event = copy_event
with stream(copy_stream, main_stream):
copy_stream.wait_stream(main_stream)
self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
if self.is_valid is not None:
self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
else:
self.is_valid_cpu = None
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
pooler_output = self.pooler_output_cpu.unbind(dim=0)
if self.is_valid_cpu is not None:
is_valid_cpu = self.is_valid_cpu.tolist()
for i, is_valid in enumerate(is_valid_cpu):
if not is_valid:
pooler_output[i] = None
self.model_runner_output.pooler_output = pooler_output
return self.model_runner_output
def async_copy_to_np(x: torch.Tensor) -> np.ndarray: def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy() return x.to("cpu", non_blocking=True).numpy()

View File

@@ -499,6 +499,38 @@ def post_update(
) )
@triton.jit
def _post_update_pool_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
query_start_loc_ptr,
):
batch_id = tl.program_id(0)
query_start = tl.load(query_start_loc_ptr + batch_id)
query_end = tl.load(query_start_loc_ptr + batch_id + 1)
query_len = query_end - query_start
req_state_idx = tl.load(idx_mapping_ptr + batch_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)
def post_update_pool(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_pool_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
query_start_loc,
)
@triton.jit @triton.jit
def _expand_idx_mapping_kernel( def _expand_idx_mapping_kernel(
idx_mapping_ptr, idx_mapping_ptr,

View File

@@ -38,13 +38,14 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
get_kv_cache_spec, get_kv_cache_spec,
@@ -66,6 +67,7 @@ from vllm.v1.worker.gpu.input_batch import (
expand_idx_mapping, expand_idx_mapping,
get_num_sampled_and_rejected, get_num_sampled_and_rejected,
post_update, post_update,
post_update_pool,
prepare_pos_seq_lens, prepare_pos_seq_lens,
prepare_prefill_inputs, prepare_prefill_inputs,
) )
@@ -77,6 +79,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.model_states import ModelState from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -119,7 +122,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype self.cache_config.cache_dtype
] ]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size() self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
@@ -217,6 +219,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured. # KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call. # For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None self.execute_model_state: tuple | None = None
@@ -224,9 +230,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len self.req_states.max_model_len = max_model_len
@staticmethod def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
def get_supported_tasks() -> tuple[str]: tasks: list[SupportedTask] = []
return ("generate",) if self.model_config.runner_type == "generate":
tasks.append("generate")
if self.pooling_runner is not None:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
return tuple(tasks)
def load_model(self, *args, **kwargs) -> None: def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
@@ -263,6 +273,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Initialize the components that require the model. # Initialize the components that require the model.
self.model_state = ModelState(self.vllm_config, self.model, self.device) self.model_state = ModelState(self.vllm_config, self.model, self.device)
if self.is_pooling_model:
self.pooling_runner = PoolingRunner(self.model)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
@@ -388,16 +400,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
expanded_local_pos, expanded_local_pos,
) )
@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
assert self.pooling_runner is not None
self.pooling_runner.dummy_pooler_run(hidden_states)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run( hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True self.max_num_tokens, skip_attn=True
) )
# Only run sampler on last PP rank (non-last ranks return None). # Only run sampler/pooler on last PP rank (non-last ranks return None).
if self.is_last_pp_rank: if self.is_last_pp_rank:
assert sample_hidden_states is not None assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states) if self.pooling_runner is None:
self._dummy_sampler_run(sample_hidden_states)
else:
self._dummy_pooler_run(hidden_states)
if self.speculator is not None: if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp( num_tokens_across_dp = make_num_tokens_across_dp(
@@ -505,7 +525,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.prompt_token_ids is not None assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None assert new_req_data.prefill_token_ids is not None
assert new_req_data.sampling_params is not None
req_id = new_req_data.req_id req_id = new_req_data.req_id
prompt_len = len(new_req_data.prompt_token_ids) prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request( self.req_states.add_request(
@@ -523,14 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.block_tables.append_block_ids( self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True req_index, new_req_data.block_ids, overwrite=True
) )
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request) self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if new_req_data.sampling_params is not None:
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
if scheduler_output.scheduled_new_reqs: if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes() self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes() self.sampler.apply_staged_writes()
@@ -1083,3 +1104,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def take_draft_token_ids(self) -> DraftTokenIds | None: def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.draft_tokens_handler.get_draft_tokens() return self.draft_tokens_handler.get_draft_tokens()
@torch.inference_mode()
def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
input_batch, _, _, _, hidden_states, _, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None
if not self.is_last_pp_rank:
self.postprocess_pool(input_batch)
return None
assert self.pooling_runner is not None
pooler_output, is_valid = self.pooling_runner.pool(
hidden_states, input_batch, self.req_states
)
self.postprocess_pool(input_batch)
# Build the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
kv_connector_output=kv_connector_output,
)
async_output = AsyncPoolingOutput(
model_runner_output=model_runner_output,
pooler_output=pooler_output,
is_valid=is_valid,
main_stream=self.main_stream,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def postprocess_pool(self, input_batch: InputBatch) -> None:
# Update the number of computed tokens.
post_update_pool(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
input_batch.query_start_loc,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)

View File

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState
# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
def __init__(self, model: nn.Module):
self.model = cast(VllmModelForPooling, model)
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
if not is_pooling_model(self.model):
return []
assert "embed" in self.model.pooler.get_supported_tasks()
return ["embed"]
def pool(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
req_states: RequestState,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# TODO(woosuk): Support different types of pooling tasks.
last_hidden_states = hidden_states[input_batch.logits_indices]
# TODO(woosuk): Make normalization optional.
last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)
prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
is_valid = input_batch.seq_lens == prompt_len
return last_hidden_states, is_valid
def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
F.normalize(hidden_states, p=2, dim=-1)
return

View File

@@ -700,6 +700,12 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors scheduler_output, intermediate_tensors
) )
if (
self.use_v2_model_runner
and self.model_runner.is_pooling_model
and output is None
):
output = self.model_runner.pool() # type: ignore
if isinstance( if isinstance(
output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
): ):