[Model Runner V2] Support pooling models (#35120)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
        return self.model_runner_output


class AsyncPoolingOutput(AsyncModelRunnerOutput):
    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        pooler_output: torch.Tensor,
        is_valid: torch.Tensor | None,
        main_stream: torch.cuda.Stream,
        copy_stream: torch.cuda.Stream,
        copy_event: torch.cuda.Event,
    ):
        self.model_runner_output = model_runner_output
        self.pooler_output = pooler_output
        self.is_valid = is_valid
        self.copy_event = copy_event

        with stream(copy_stream, main_stream):
            copy_stream.wait_stream(main_stream)
            self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
            if self.is_valid is not None:
                self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
            else:
                self.is_valid_cpu = None
            self.copy_event.record(copy_stream)

    def get_output(self) -> ModelRunnerOutput:
        self.copy_event.synchronize()
        # unbind() returns an immutable tuple; convert to a list so invalid
        # entries can be replaced with None below.
        pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
        if self.is_valid_cpu is not None:
            is_valid_cpu = self.is_valid_cpu.tolist()
            for i, is_valid in enumerate(is_valid_cpu):
                if not is_valid:
                    pooler_output[i] = None
        self.model_runner_output.pooler_output = pooler_output
        return self.model_runner_output


def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
    return x.to("cpu", non_blocking=True).numpy()
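Editor's note: a minimal standalone sketch of the stream/event pattern AsyncPoolingOutput relies on. Tensor shapes are hypothetical and torch.cuda.stream is used in place of the stream() helper imported above; the device-to-host copy is issued on a side stream so the main stream is not stalled.

import torch

main_stream = torch.cuda.current_stream()
copy_stream = torch.cuda.Stream()
copy_event = torch.cuda.Event()

# Hypothetical pooler result produced on the main stream: [num_reqs, hidden_size].
pooler_output = torch.randn(4, 1024, device="cuda")

copy_stream.wait_stream(main_stream)  # the copy must wait for the pooler result
with torch.cuda.stream(copy_stream):
    # For a truly asynchronous copy the CPU destination should be pinned;
    # plain pageable memory keeps this sketch short.
    pooler_output_cpu = pooler_output.to("cpu", non_blocking=True)
    copy_event.record(copy_stream)

# Later, when the output is consumed (get_output() in the class above):
copy_event.synchronize()
embeddings = list(pooler_output_cpu.unbind(dim=0))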
@@ -499,6 +499,38 @@ def post_update(
    )


@triton.jit
def _post_update_pool_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
):
    batch_id = tl.program_id(0)
    query_start = tl.load(query_start_loc_ptr + batch_id)
    query_end = tl.load(query_start_loc_ptr + batch_id + 1)
    query_len = query_end - query_start

    req_state_idx = tl.load(idx_mapping_ptr + batch_id)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)


def post_update_pool(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
) -> None:
    num_reqs = idx_mapping.shape[0]
    _post_update_pool_kernel[(num_reqs,)](
        idx_mapping,
        num_computed_tokens,
        query_start_loc,
    )


@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,
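Editor's note: what _post_update_pool_kernel computes can be written as a pure-PyTorch reference, sketched here under the assumption (as in the launch above) that each request-state index appears at most once in idx_mapping.

import torch

def post_update_pool_reference(
    idx_mapping: torch.Tensor,          # [num_reqs], indices into the request-state table
    num_computed_tokens: torch.Tensor,  # [max_num_reqs], updated in place
    query_start_loc: torch.Tensor,      # [num_reqs + 1], cumulative token offsets
) -> None:
    # Number of tokens scheduled for each request in this step.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # One slot per request, so an indexed add mirrors the kernel's
    # per-program load / add / store.
    num_computed_tokens[idx_mapping] += query_lens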
@@ -38,13 +38,14 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
 from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
-from vllm.v1.worker.gpu.async_utils import AsyncOutput
+from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
 from vllm.v1.worker.gpu.attn_utils import (
     build_slot_mappings_by_layer,
     get_kv_cache_spec,
@@ -66,6 +67,7 @@ from vllm.v1.worker.gpu.input_batch import (
     expand_idx_mapping,
     get_num_sampled_and_rejected,
     post_update,
+    post_update_pool,
     prepare_pos_seq_lens,
     prepare_prefill_inputs,
 )
@@ -77,6 +79,7 @@ from vllm.v1.worker.gpu.kv_connector import (
 from vllm.v1.worker.gpu.lora_utils import LoraState
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.model_states import ModelState
+from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
 from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -119,7 +122,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             self.cache_config.cache_dtype
         ]
-        self.is_pooling_model = False

         self.vocab_size = self.model_config.get_vocab_size()
         self.max_model_len = self.model_config.max_model_len
@@ -217,6 +219,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

+        # Pooling models.
+        self.is_pooling_model = self.model_config.runner_type == "pooling"
+        self.pooling_runner: PoolingRunner | None = None
+
         # For transferring state from execute_model to subsequent sample_tokens call.
         self.execute_model_state: tuple | None = None
@@ -224,9 +230,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len

-    @staticmethod
-    def get_supported_tasks() -> tuple[str]:
-        return ("generate",)
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks: list[SupportedTask] = []
+        if self.model_config.runner_type == "generate":
+            tasks.append("generate")
+        if self.pooling_runner is not None:
+            tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
+        return tuple(tasks)

     def load_model(self, *args, **kwargs) -> None:
         time_before_load = time.perf_counter()
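Editor's note: a hedged illustration of the rewritten task discovery, using hypothetical runner instances and the values reported by the PoolingRunner added later in this commit (which currently only advertises "embed").

# generate_runner and embed_runner are hypothetical GPUModelRunner instances.
assert generate_runner.get_supported_tasks() == ("generate",)   # runner_type == "generate"
assert embed_runner.get_supported_tasks() == ("embed",)         # pooling model with a PoolingRunner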
@@ -263,6 +273,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Initialize the components that require the model.
         self.model_state = ModelState(self.vllm_config, self.model, self.device)
+        if self.is_pooling_model:
+            self.pooling_runner = PoolingRunner(self.model)

     def get_model(self) -> nn.Module:
         return self.model
@@ -388,16 +400,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             expanded_local_pos,
         )

+    @torch.inference_mode()
+    def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
+        assert self.pooling_runner is not None
+        self.pooling_runner.dummy_pooler_run(hidden_states)
+
     @torch.inference_mode()
     def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
             self.max_num_tokens, skip_attn=True
         )

-        # Only run sampler on last PP rank (non-last ranks return None).
+        # Only run sampler/pooler on last PP rank (non-last ranks return None).
         if self.is_last_pp_rank:
             assert sample_hidden_states is not None
+            if self.pooling_runner is None:
+                self._dummy_sampler_run(sample_hidden_states)
+            else:
+                self._dummy_pooler_run(hidden_states)

         if self.speculator is not None:
             num_tokens_across_dp = make_num_tokens_across_dp(
@@ -505,7 +525,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for new_req_data in scheduler_output.scheduled_new_reqs:
             assert new_req_data.prompt_token_ids is not None
             assert new_req_data.prefill_token_ids is not None
-            assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
             prompt_len = len(new_req_data.prompt_token_ids)
             self.req_states.add_request(
@@ -523,13 +542,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.block_tables.append_block_ids(
                 req_index, new_req_data.block_ids, overwrite=True
             )
-            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

+            if new_req_data.sampling_params is not None:
                 self.sampler.add_request(
                     req_index, prompt_len, new_req_data.sampling_params
                 )
                 self.prompt_logprobs_worker.add_request(
                     req_id, req_index, new_req_data.sampling_params
                 )
+            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

         if scheduler_output.scheduled_new_reqs:
             self.req_states.apply_staged_writes()
@@ -1083,3 +1104,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        return self.draft_tokens_handler.get_draft_tokens()

    @torch.inference_mode()
    def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None

        input_batch, _, _, _, hidden_states, _, kv_connector_output = (
            self.execute_model_state
        )
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            self.postprocess_pool(input_batch)
            return None

        assert self.pooling_runner is not None
        pooler_output, is_valid = self.pooling_runner.pool(
            hidden_states, input_batch, self.req_states
        )
        self.postprocess_pool(input_batch)

        # Build the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncPoolingOutput(
            model_runner_output=model_runner_output,
            pooler_output=pooler_output,
            is_valid=is_valid,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )
        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def postprocess_pool(self, input_batch: InputBatch) -> None:
        # Update the number of computed tokens.
        post_update_pool(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens.gpu,
            input_batch.query_start_loc,
        )

        # Update the number of computed prefill tokens.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(
            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
        )
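Editor's note: a worked example with made-up numbers of the prefill bookkeeping in postprocess_pool; the np.minimum clamp keeps chunked-prefill requests from counting past their prompt length.

import numpy as np

idx_mapping_np = np.array([0, 2, 3])            # request slots touched this step
computed_prefill = np.array([0, 0, 48, 90])     # prefill tokens computed so far, per slot
num_scheduled_tokens = np.array([16, 16, 8])    # tokens scheduled this step
prefill_len = np.array([10, 64, 60, 100])       # full prompt length per slot

computed_prefill[idx_mapping_np] += num_scheduled_tokens
np.minimum(computed_prefill, prefill_len, out=computed_prefill)
# computed_prefill -> [10, 0, 60, 98]; slots 0 and 2 are clamped to their prompt lengths.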
vllm/v1/worker/gpu/pool/__init__.py (new file, 0 lines)
vllm/v1/worker/gpu/pool/pooling_runner.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState


# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
    def __init__(self, model: nn.Module):
        self.model = cast(VllmModelForPooling, model)

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        if not is_pooling_model(self.model):
            return []
        assert "embed" in self.model.pooler.get_supported_tasks()
        return ["embed"]

    def pool(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # TODO(woosuk): Support different types of pooling tasks.
        last_hidden_states = hidden_states[input_batch.logits_indices]
        # TODO(woosuk): Make normalization optional.
        last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)

        prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
        is_valid = input_batch.seq_lens == prompt_len
        return last_hidden_states, is_valid

    def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        F.normalize(hidden_states, p=2, dim=-1)
        return
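Editor's note: a small example with made-up lengths of the is_valid flag returned by PoolingRunner.pool(). Only requests whose entire prompt has been processed produce an embedding this step; the rest are replaced with None by AsyncPoolingOutput.get_output().

import torch

seq_lens = torch.tensor([32, 20, 128])     # tokens processed so far per request
prompt_len = torch.tensor([32, 64, 128])   # full prompt length per request

is_valid = seq_lens == prompt_len
# tensor([ True, False,  True]): request 1 is still mid-prefill (chunked prefill),
# so its pooler output is dropped (set to None) when the output is built.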
@@ -700,6 +700,12 @@ class Worker(WorkerBase):
         output = self.model_runner.execute_model(
             scheduler_output, intermediate_tensors
         )
+        if (
+            self.use_v2_model_runner
+            and self.model_runner.is_pooling_model
+            and output is None
+        ):
+            output = self.model_runner.pool()  # type: ignore
         if isinstance(
             output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
         ):