[Model Runner V2] Add model states [1/N] (#35350)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
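In outline: the runner and CUDA graph manager stop passing input_ids / positions / inputs_embeds explicitly and instead build a model_inputs dict of default buffer slices, into which the new ModelState.prepare_inputs() / prepare_dummy_inputs() merge per-model overrides (currently only M-RoPE positions). A minimal, standalone sketch of that dict-merge override pattern, using illustrative names rather than vLLM's real classes:

# Minimal sketch of the override pattern used in this commit: later keys in a
# dict literal replace earlier ones, so whatever prepare_inputs() returns wins
# over the defaults. Names here are illustrative, not vLLM APIs.
import torch


def prepare_inputs(uses_mrope: bool, num_tokens: int) -> dict[str, torch.Tensor]:
    if not uses_mrope:
        return {}  # common case: keep the default 1D positions
    # M-RoPE models need [3, num_tokens] positions instead.
    return {"positions": torch.zeros(3, num_tokens, dtype=torch.int64)}


num_tokens = 8
model_inputs = {
    "input_ids": torch.zeros(num_tokens, dtype=torch.int64),
    "positions": torch.zeros(num_tokens, dtype=torch.int64),
    # Overrides (if any) are merged last, so they take precedence.
    **prepare_inputs(uses_mrope=True, num_tokens=num_tokens),
}
assert model_inputs["positions"].shape == (3, num_tokens)

Because the overrides are expanded last in the dict literal, any key they return (such as "positions") replaces the default entry above it.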
@@ -22,6 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.utils import AttentionGroup


@@ -29,13 +30,11 @@ class CudaGraphManager:
    def __init__(
        self,
        vllm_config: VllmConfig,
        uses_mrope: bool,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.uses_mrope = uses_mrope
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
        self.device = device

@@ -88,8 +87,8 @@ class CudaGraphManager:
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
@@ -113,13 +112,18 @@ class CudaGraphManager:
            )
        else:
            num_reqs = min(num_tokens, self.max_num_reqs)
        input_ids = input_buffers.input_ids[:num_tokens]
        positions = input_buffers.positions[:num_tokens]
        if self.uses_mrope:
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:num_tokens]

        model_inputs = {
            "input_ids": input_buffers.input_ids[:num_tokens],
            "positions": input_buffers.positions[:num_tokens],
            "inputs_embeds": (
                inputs_embeds[:num_tokens] if inputs_embeds is not None else None
            ),
            # NOTE: Values returned by `prepare_dummy_inputs` will override the
            # default values above.
            **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
        }

        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
@@ -143,11 +147,7 @@ class CudaGraphManager:
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            model_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            model_output = model(**model_inputs)
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
@@ -164,9 +164,7 @@ class CudaGraphManager:
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            model_inputs=model_inputs,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
@@ -178,9 +176,7 @@ class CudaGraphManager:
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
@@ -206,11 +202,8 @@ class CudaGraphManager:
            ),
            torch.cuda.graph(graph, self.pool),
        ):
            model_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            model_output = model(**model_inputs)

            # Join offloader's copy stream after forward to avoid unjoined
            # stream error. The last layer's start_prefetch forks copy_stream,
            # but wait_prefetch only happens in the next forward pass.
@@ -235,9 +228,7 @@ class CudaGraphManager:
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
@@ -256,18 +247,14 @@ class CudaGraphManager:
            batch_descriptor=batch_descriptor,
            slot_mapping=slot_mappings,
        ):
            model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            model(**model_inputs)

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
@@ -278,8 +265,8 @@ class CudaGraphManager:
            device=self.device,
            capture_fn=self.capture_graph,
            model=model,
            model_state=model_state,
            input_buffers=input_buffers,
            mrope_positions=mrope_positions,
            inputs_embeds=inputs_embeds,
            block_tables=block_tables,
            attn_groups=attn_groups,
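The call sites shown above (graph capture, warm-up, and the dummy run) all go through model(**model_inputs), where every tensor is a fixed-size slice of a persistent buffer (input_buffers.input_ids[:num_tokens], mrope_positions[:, :num_tokens], and so on). For context, that is what CUDA graph capture requires: the graph records device pointers, so replay must reuse the same memory. A minimal capture-and-replay sketch with a static buffer, independent of vLLM's classes and assuming a CUDA device is available:

import torch

# Static buffer: capture records its pointer, so replay must reuse this memory.
static_x = torch.zeros(8, device="cuda")
model = torch.nn.Linear(8, 8).cuda()

# Warm up on a side stream before capture (required by CUDA graphs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = model(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = model(static_x)  # output tensor is also fixed across replays

# Replay: copy new data into the static buffer, then replay the recorded graph.
static_x.copy_(torch.randn(8, device="cuda"))
graph.replay()
print(static_y)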
@@ -65,8 +65,6 @@ class InputBatch:
    input_ids: torch.Tensor
    # [num_tokens_after_padding]
    positions: torch.Tensor
    # [3, num_tokens_after_padding]
    mrope_positions: torch.Tensor | None
    # [num_tokens_after_padding, hidden_size]
    inputs_embeds: torch.Tensor | None
@@ -143,7 +141,6 @@ class InputBatch:
            seq_lens=seq_lens,
            input_ids=input_ids,
            positions=positions,
            mrope_positions=None,
            inputs_embeds=None,
            attn_metadata=None,  # type: ignore
            slot_mappings=None,  # type: ignore
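The shape comments above encode the batch layout: tokens from all requests are flattened into one dimension of length num_tokens_after_padding, ordinary RoPE positions are a 1D tensor over that dimension, and M-RoPE positions carry three rows, one per position axis. A small illustrative dataclass, not vLLM's InputBatch:

from dataclasses import dataclass

import torch


@dataclass
class TinyBatch:
    # Illustrative only; mirrors the shape comments above.
    # [num_tokens]
    input_ids: torch.Tensor
    # [num_tokens] for ordinary RoPE, or [3, num_tokens] for M-RoPE,
    # where the 3 rows are the per-axis position components.
    positions: torch.Tensor


num_tokens = 4
plain = TinyBatch(
    input_ids=torch.arange(num_tokens),
    positions=torch.arange(num_tokens),
)
mrope = TinyBatch(
    input_ids=torch.arange(num_tokens),
    positions=torch.arange(num_tokens).repeat(3, 1),  # [3, num_tokens]
)
print(plain.positions.shape, mrope.positions.shape)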
@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import (
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -140,14 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            dtype=self.dtype,
            device=self.device,
        )
        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_states = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
@@ -212,7 +204,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            self.vllm_config,
            self.uses_mrope,
            self.use_aux_hidden_state_outputs,
            self.device,
        )
@@ -271,6 +262,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if self.speculator is not None:
            prepare_communication_buffer_for_model(self.speculator)

        # Initialize the components that require the model.
        self.model_state = ModelState(self.vllm_config, self.model, self.device)

    def get_model(self) -> nn.Module:
        return self.model
@@ -481,16 +475,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            mrope_positions = None
            if self.uses_mrope:
                mrope_positions = self.mrope_states.mrope_positions
            inputs_embeds = None
            if self.supports_mm_inputs:
                inputs_embeds = self.encoder_runner.inputs_embeds
            self.cudagraph_manager.capture(
                model=self.model,
                model_state=self.model_state,
                input_buffers=self.input_buffers,
                mrope_positions=mrope_positions,
                inputs_embeds=inputs_embeds,
                block_tables=self.block_tables,
                attn_groups=self.attn_groups,
@@ -554,14 +545,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            if self.supports_mm_inputs:
                self.encoder_runner.add_request(req_id, new_req_data.mm_features)

            # Pre-compute M-RoPE positions for prefill.
            if self.uses_mrope:
                self.mrope_states.init_prefill_mrope_positions(
                    req_index,
                    self.model,  # type: ignore
                    new_req_data.prefill_token_ids,
                    mm_features=new_req_data.mm_features,
                )
            self.model_state.add_request(req_index, new_req_data)

            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
@@ -577,8 +561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if scheduler_output.scheduled_new_reqs:
            self.req_states.apply_staged_writes()
            self.sampler.apply_staged_writes()
            if self.uses_mrope:
                self.mrope_states.apply_staged_writes()
            self.model_state.apply_staged_writes()

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        # Add new blocks for the existing requests.
@@ -692,15 +675,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )
        dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]

        # Prepare M-RoPE positions.
        if self.uses_mrope:
            self.mrope_states.prepare_mrope_positions(
                idx_mapping,
                query_start_loc,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
@@ -744,10 +718,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
        positions = self.input_buffers.positions[:num_tokens_after_padding]
        mrope_positions = None
        if self.uses_mrope:
            mrope_positions = self.mrope_states.mrope_positions
            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
@@ -764,7 +734,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            seq_lens=seq_lens,
            input_ids=input_ids,
            positions=positions,
            mrope_positions=mrope_positions,
            inputs_embeds=None,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings_by_layer,
@@ -959,14 +928,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            input_buffers=self.input_buffers,
            device=self.device,
        )
        if self.uses_mrope:
            input_batch.mrope_positions = self.mrope_states.mrope_positions[
                :, :num_tokens_after_padding
            ]
        if not skip_attn_for_dummy_run:
            self.prepare_dummy_attn_metadata(input_batch)
        # FIXME(woosuk): Fix warmup for LoRA.

        model_inputs = {
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
            "inputs_embeds": input_batch.inputs_embeds,
            # NOTE: Values returned by `prepare_inputs` will override the default
            # values above.
            **self.model_state.prepare_inputs(input_batch, self.req_states),
        }
        if not self.is_first_pp_rank:
            # Update for non-first PP ranks.
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = None
            model_inputs["intermediate_tensors"] = intermediate_tensors

        # Run model.
        if cudagraph_runtime_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
@@ -983,20 +962,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                aux_hidden_states = None
        else:
            # For piecewise and eager mode, just call model().
            positions = input_batch.positions
            if self.uses_mrope:
                assert input_batch.mrope_positions is not None
                positions = input_batch.mrope_positions

            if self.is_first_pp_rank:
                input_ids = input_batch.input_ids
                inputs_embeds = input_batch.inputs_embeds
                assert intermediate_tensors is None
            else:
                input_ids = None
                inputs_embeds = None
                assert intermediate_tensors is not None

            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
@@ -1012,12 +977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                slot_mapping=input_batch.slot_mappings,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                    intermediate_tensors=intermediate_tensors,
                )
                model_output = self.model(**model_inputs)
                if self.use_aux_hidden_state_outputs:
                    hidden_states, aux_hidden_states = model_output
                else:
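One case the model_inputs dict also has to cover is pipeline parallelism: on non-first PP ranks the hunks above null out input_ids and inputs_embeds and pass intermediate_tensors instead, so the model consumes the previous rank's hidden states rather than embedding tokens itself. A minimal sketch of that switch, with illustrative names only:

import torch


def build_model_inputs(
    is_first_pp_rank: bool,
    input_ids: torch.Tensor,
    intermediate: torch.Tensor | None,
) -> dict[str, torch.Tensor | None]:
    model_inputs: dict[str, torch.Tensor | None] = {
        "input_ids": input_ids,
        "inputs_embeds": None,
    }
    if not is_first_pp_rank:
        # Non-first ranks never embed tokens themselves; they take the
        # previous rank's hidden states as input.
        model_inputs["input_ids"] = None
        model_inputs["intermediate_tensors"] = intermediate
    return model_inputs


first = build_model_inputs(True, torch.arange(4), None)
later = build_model_inputs(False, torch.arange(4), torch.zeros(4, 8))
print(sorted(first.keys()), sorted(later.keys()))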
vllm/v1/worker/gpu/model_states.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState


class ModelState:
    def __init__(self, vllm_config: VllmConfig, model: nn.Module, device: torch.device):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.device = device

        self.max_model_len = self.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens

        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_state = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        if self.uses_mrope:
            # Pre-compute M-RoPE positions for prefill.
            assert new_req_data.prefill_token_ids is not None
            self.mrope_state.init_prefill_mrope_positions(
                req_index,
                self.model,  # type: ignore
                new_req_data.prefill_token_ids,
                mm_features=new_req_data.mm_features,
            )

    def apply_staged_writes(self) -> None:
        if self.uses_mrope:
            self.mrope_state.apply_staged_writes()

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        if not self.uses_mrope:
            # Common case (1D positions).
            return {}

        # Prepare M-RoPE positions.
        self.mrope_state.prepare_mrope_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        mrope_positions = self.mrope_state.mrope_positions[
            :, : input_batch.num_tokens_after_padding
        ]
        return {"positions": mrope_positions}

    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        if not self.uses_mrope:
            return {}
        mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
        return {"positions": mrope_positions}