[Model Runner V2] Change bookkeeping logic in preparation for spec decoding (#29194)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-11-23 09:42:52 -08:00
Committed by: GitHub
Parent: 6fb0215eee
Commit: 7f12c82fa6
6 changed files with 269 additions and 140 deletions

View File

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from contextlib import contextmanager
-import numpy as np
 import torch
 from vllm.v1.outputs import (
@@ -18,7 +17,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
         self,
         model_runner_output: ModelRunnerOutput,
         sampler_output: SamplerOutput,
-        num_sampled_tokens: np.ndarray,
+        num_sampled_tokens: torch.Tensor,
         copy_stream: torch.cuda.Stream,
         copy_event: torch.cuda.Event,
     ):
@@ -52,6 +51,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
             )
         else:
             self.logprobs_tensors = None
+        self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True)
         self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
         if self.model_runner_output.prompt_logprobs_dict:
             for k, v in self.model_runner_output.prompt_logprobs_dict.items():
@@ -63,6 +63,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
     def get_output(self) -> ModelRunnerOutput:
         self.copy_event.synchronize()
+        num_sampled_tokens_np = self.num_sampled_tokens.numpy()
         # NOTE(woosuk): The following code is to ensure compatibility with
         # the existing model runner.
@@ -71,7 +72,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
         sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
         num_reqs = len(sampled_token_ids)
         for i in range(num_reqs):
-            del sampled_token_ids[i][self.num_sampled_tokens[i] :]
+            del sampled_token_ids[i][num_sampled_tokens_np[i] :]
         self.model_runner_output.sampled_token_ids = sampled_token_ids
         if self.logprobs_tensors is not None:
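
The hunks above turn num_sampled_tokens into a GPU tensor that is copied to the CPU with non_blocking=True and only converted to NumPy after copy_event has been synchronized in get_output(). A minimal sketch of that copy-then-synchronize pattern (the helper name and call site are illustrative, not vLLM APIs):

import torch

def async_d2h_copy(num_sampled_gpu: torch.Tensor,
                   copy_stream: torch.cuda.Stream,
                   copy_event: torch.cuda.Event) -> torch.Tensor:
    # Launch the device-to-host copy on a side stream so it can overlap
    # with other GPU work; record an event to know when it has finished.
    with torch.cuda.stream(copy_stream):
        num_sampled_cpu = num_sampled_gpu.to("cpu", non_blocking=True)
        copy_event.record(copy_stream)
    return num_sampled_cpu

# Later, on the consumer side (mirrors AsyncOutput.get_output()):
#   copy_event.synchronize()
#   num_sampled_np = num_sampled_cpu.numpy()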

View File

@@ -3,6 +3,7 @@
 from collections.abc import Sequence
 from typing import Any, cast
+import numpy as np
 import torch
 from vllm.attention.backends.abstract import AttentionBackend
@@ -145,8 +146,9 @@ def build_attn_metadata(
     num_reqs: int,
     num_tokens: int,
     query_start_loc: CpuGpuBuffer,
-    seq_lens: CpuGpuBuffer,
-    num_computed_tokens_cpu: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_lens_np: np.ndarray,
+    num_computed_tokens_cpu: torch.Tensor | None,
     block_tables: Sequence[torch.Tensor],
     slot_mappings: torch.Tensor,
     kv_cache_config: KVCacheConfig,
@@ -154,9 +156,9 @@ def build_attn_metadata(
     query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
     query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
     max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
-    seq_lens_gpu = seq_lens.gpu[:num_reqs]
-    seq_lens_cpu = seq_lens.cpu[:num_reqs]
-    max_seq_len = int(seq_lens.np[:num_reqs].max())
+    seq_lens = seq_lens[:num_reqs]
+    seq_lens_cpu = torch.from_numpy(seq_lens_np)
+    max_seq_len = int(seq_lens_np.max())
     attn_metadata: dict[str, Any] = {}
     kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -167,7 +169,7 @@ def build_attn_metadata(
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc_gpu,
             query_start_loc_cpu=query_start_loc_cpu,
-            seq_lens=seq_lens_gpu,
+            seq_lens=seq_lens,
             seq_lens_cpu=seq_lens_cpu,
             max_seq_len=max_seq_len,
             num_computed_tokens_cpu=num_computed_tokens_cpu,
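
Note that seq_lens_cpu above is built with torch.from_numpy, which shares memory with seq_lens_np rather than copying it. A small, self-contained illustration (values are made up):

import numpy as np
import torch

seq_lens_np = np.array([17, 32, 5], dtype=np.int32)
seq_lens_cpu = torch.from_numpy(seq_lens_np)  # zero-copy view of the same buffer
seq_lens_np[0] = 18
assert seq_lens_cpu[0].item() == 18  # the tensor observes the update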

View File

@@ -101,14 +101,13 @@ class CudaGraphManager:
         # Prepare dummy inputs.
         input_ids = input_buffers.input_ids.gpu[:batch_size]
-        positions = input_buffers.positions.gpu[:batch_size]
+        positions = input_buffers.positions[:batch_size]
         input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
         input_buffers.query_start_loc.np[batch_size:] = batch_size
         input_buffers.query_start_loc.copy_to_gpu()
-        input_buffers.seq_lens.np[:batch_size] = self.max_model_len
-        input_buffers.seq_lens.np[batch_size:] = 0
-        input_buffers.seq_lens.copy_to_gpu()
+        input_buffers.seq_lens[:batch_size] = self.max_model_len
+        input_buffers.seq_lens[batch_size:] = 0
         input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
         slot_mappings = block_tables.slot_mappings[:, :batch_size]
@@ -119,6 +118,7 @@ class CudaGraphManager:
             num_tokens=batch_size,
             query_start_loc=input_buffers.query_start_loc,
             seq_lens=input_buffers.seq_lens,
+            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
             num_computed_tokens_cpu=None,  # FIXME
             block_tables=input_block_tables,
             slot_mappings=slot_mappings,
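
With seq_lens now a plain GPU tensor, the dummy-capture path writes max_model_len directly into the buffer and hands the attention metadata builder a matching host array via seq_lens_np. A rough sketch of the equivalent setup, with illustrative sizes:

import numpy as np
import torch

max_num_reqs, batch_size, max_model_len = 8, 4, 4096
seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device="cuda")

# Slots used by the dummy batch get the worst-case length; the rest are
# zeroed so full CUDA graphs see deterministic padding.
seq_lens[:batch_size] = max_model_len
seq_lens[batch_size:] = 0

# Host-side lengths passed separately as seq_lens_np.
seq_lens_np = np.full(batch_size, max_model_len, dtype=np.int32)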

View File

@@ -32,9 +32,9 @@ class InputBuffers:
         self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
         self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
+        self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
         self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
-        self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
+        self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
         # Structured outputs.
         self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
@@ -107,13 +107,15 @@ class InputBatch:
         query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
         query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
         # seq_len equals to query_len
-        input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
-        input_buffers.seq_lens.np[num_reqs:] = 0
-        seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
-        seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
+        seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
+        seq_lens_np[-1] += num_tokens % num_reqs
+        input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
+        input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
+        input_buffers.seq_lens[num_reqs:] = 0
+        seq_lens = input_buffers.seq_lens[:num_reqs]
         input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
-        positions = input_buffers.positions.copy_to_gpu(num_tokens)
+        positions = input_buffers.positions[:num_tokens]
         # attn_metadata = defaultdict(lambda: None)
         logits_indices = query_start_loc[1:] - 1
         return cls(
@@ -141,27 +143,25 @@ class InputBatch:
     [
         types.none(
             types.int32[:],  # idx_mapping
-            types.int32[:, :],  # token_ids
-            types.int32[:],  # num_computed_tokens
             types.int32[:],  # num_scheduled_tokens
+            types.int32[:, :],  # prefill_token_ids
+            types.int32[:],  # num_computed_prefill_tokens
+            types.int32[:],  # prefill_len
             types.int32[:],  # input_ids
-            types.int64[:],  # positions
             types.int32[:],  # query_start_loc
-            types.int32[:],  # seq_lens
         )
     ],
     nopython=True,
     cache=True,
 )
-def _prepare_inputs(
+def _prepare_prefill_inputs(
     idx_mapping: np.ndarray,  # batch_idx -> req_idx
-    token_ids: np.ndarray,  # [N, max_model_len]
-    num_computed_tokens: np.ndarray,  # [N]
     num_scheduled_tokens: np.ndarray,  # [B]
+    prefill_token_ids: np.ndarray,  # [N, max_model_len]
+    num_computed_prefill_tokens: np.ndarray,  # [N]
+    prefill_len: np.ndarray,  # [N]
     input_ids: np.ndarray,  # [num_input_tokens]
-    positions: np.ndarray,  # [num_input_tokens]
     query_start_loc: np.ndarray,  # [B + 1]
-    seq_lens: np.ndarray,  # [B]
 ) -> None:
     num_reqs = num_scheduled_tokens.shape[0]
     query_start_loc[0] = 0
@@ -170,62 +170,112 @@ def _prepare_inputs(
     for i in range(num_reqs):
         req_idx = idx_mapping[i]
         query_len = num_scheduled_tokens[i]
-        start = num_computed_tokens[req_idx]
-        end = start + query_len
-        seq_lens[i] = end
+        start = num_computed_prefill_tokens[req_idx]
+        end = min(start + query_len, prefill_len[req_idx])
+        n = end - start
         start_idx = cu_num_tokens
-        end_idx = start_idx + query_len
-        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
-        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
-        cu_num_tokens = end_idx
+        input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]
+        cu_num_tokens = start_idx + query_len
         query_start_loc[i + 1] = cu_num_tokens
     # Pad the inputs for CUDA graphs.
     # Note: pad query_start_loc to be non-decreasing, as kernels
     # like FlashAttention requires that
     query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
-    # Fill unused with 0 for full cuda graph mode.
-    seq_lens[num_reqs:].fill(0)

-def prepare_inputs(
+def prepare_prefill_inputs(
     idx_mapping: np.ndarray,
-    prefill_token_ids: np.ndarray,
-    num_computed_tokens: np.ndarray,
     num_scheduled_tokens: np.ndarray,
+    total_num_tokens: int,
+    prefill_token_ids: np.ndarray,
+    num_computed_prefill_tokens: np.ndarray,
+    prefill_len: np.ndarray,
     input_ids: CpuGpuBuffer,
-    positions: CpuGpuBuffer,
     query_start_loc: CpuGpuBuffer,
-    seq_lens: CpuGpuBuffer,
-    num_tokens: int,
 ) -> None:
-    _prepare_inputs(
+    _prepare_prefill_inputs(
         idx_mapping,
-        prefill_token_ids,
-        num_computed_tokens,
         num_scheduled_tokens,
+        prefill_token_ids,
+        num_computed_prefill_tokens,
+        prefill_len,
         input_ids.np,
-        positions.np,
         query_start_loc.np,
-        seq_lens.np,
     )
-    input_ids.copy_to_gpu(num_tokens)
-    positions.copy_to_gpu(num_tokens)
+    input_ids.copy_to_gpu(total_num_tokens)
     # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
     # tensors from CPU to GPU, because they may include paddings needed
     # for full CUDA graph mode.
     query_start_loc.copy_to_gpu()
-    seq_lens.copy_to_gpu()
+    return

 @triton.jit
-def _combine_last_token_ids_kernel(
+def _prepare_pos_seq_lens_kernel(
+    pos_ptr,
+    seq_lens_ptr,
+    idx_mapping_ptr,
+    query_start_loc_ptr,
+    num_computed_tokens_ptr,
+    max_num_reqs,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_id = tl.program_id(0)
+    num_reqs = tl.num_programs(0) - 1
+    if req_id == num_reqs:
+        # Pad unused seq_lens as 0 for full CUDA graphs.
+        for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
+            block = i + tl.arange(0, BLOCK_SIZE)
+            mask = block < max_num_reqs
+            tl.store(seq_lens_ptr + block, 0, mask=mask)
+        return
+    req_state_idx = tl.load(idx_mapping_ptr + req_id)
+    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
+    start = tl.load(query_start_loc_ptr + req_id)
+    end = tl.load(query_start_loc_ptr + req_id + 1)
+    query_len = end - start
+    seq_len = num_computed_tokens + query_len
+    tl.store(seq_lens_ptr + req_id, seq_len)
+    for i in tl.range(0, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        pos = num_computed_tokens + block
+        tl.store(pos_ptr + start + block, pos, mask=mask)

+def prepare_pos_seq_lens(
+    idx_mapping: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
+    pos: torch.Tensor,
+    seq_lens: torch.Tensor,
+) -> None:
+    num_reqs = idx_mapping.shape[0]
+    # NOTE(woosuk): We do +1 because the last thread block is used
+    # to pad unused seq_lens as 0 for full CUDA graphs.
+    _prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
+        pos,
+        seq_lens,
+        idx_mapping,
+        query_start_loc,
+        num_computed_tokens,
+        seq_lens.shape[0],
+        BLOCK_SIZE=1024,
+    )

+@triton.jit
+def _combine_sampled_and_draft_tokens_kernel(
     input_ids_ptr,
     idx_mapping_ptr,
-    last_token_ids_ptr,
+    last_sampled_tokens_ptr,
     query_start_loc_ptr,
     seq_lens_ptr,
     prefill_len_ptr,
@@ -239,26 +289,56 @@ def _combine_last_token_ids_kernel(
         # Handling prefill tokens.
         return
-    last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
+    last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
     end = tl.load(query_start_loc_ptr + batch_idx + 1)
     tl.store(input_ids_ptr + end - 1, last_token_id)

-def combine_last_token_ids(
+def combine_sampled_and_draft_tokens(
     input_ids: torch.Tensor,
     idx_mapping: torch.Tensor,
-    last_token_ids: torch.Tensor,
+    last_sampled_tokens: torch.Tensor,
     query_start_loc: torch.Tensor,
     seq_lens: torch.Tensor,
     prefill_len: torch.Tensor,
 ) -> torch.Tensor:
     num_reqs = seq_lens.shape[0]
-    _combine_last_token_ids_kernel[(num_reqs,)](
+    _combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
         input_ids,
         idx_mapping,
-        last_token_ids,
+        last_sampled_tokens,
         query_start_loc,
         seq_lens,
         prefill_len,
     )
     return input_ids

+@triton.jit
+def _update_num_computed_tokens_kernel(
+    idx_mapping_ptr,
+    num_computed_tokens_ptr,
+    query_start_loc_ptr,
+):
+    req_id = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + req_id)
+    start = tl.load(query_start_loc_ptr + req_id)
+    end = tl.load(query_start_loc_ptr + req_id + 1)
+    query_len = end - start
+    n = tl.load(num_computed_tokens_ptr + req_state_idx)
+    tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len)

+def update_num_computed_tokens(
+    idx_mapping: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
+    query_start_loc: torch.Tensor,
+) -> None:
+    num_reqs = idx_mapping.shape[0]
+    _update_num_computed_tokens_kernel[(num_reqs,)](
+        idx_mapping,
+        num_computed_tokens,
+        query_start_loc,
+    )
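
The two new Triton kernels move the positions/seq_lens computation and the num_computed_tokens update onto the GPU, so the host never needs to know how many draft tokens were accepted in the previous step. A plain PyTorch sketch of the semantics they implement (a reference for reading the kernels, not the kernels themselves; it assumes each batch slot maps to a distinct request state):

import torch

def prepare_pos_seq_lens_ref(idx_mapping, query_start_loc, num_computed_tokens,
                             pos, seq_lens):
    num_reqs = idx_mapping.shape[0]
    query_lens = query_start_loc[1 : num_reqs + 1] - query_start_loc[:num_reqs]
    computed = num_computed_tokens[idx_mapping]
    seq_lens[:num_reqs] = computed + query_lens  # seq_len = computed + scheduled
    seq_lens[num_reqs:] = 0  # pad unused slots for full CUDA graphs
    for i in range(num_reqs):
        s, e = int(query_start_loc[i]), int(query_start_loc[i + 1])
        c = int(computed[i])
        pos[s:e] = torch.arange(c, c + (e - s), device=pos.device)

def update_num_computed_tokens_ref(idx_mapping, num_computed_tokens, query_start_loc):
    num_reqs = idx_mapping.shape[0]
    query_lens = query_start_loc[1 : num_reqs + 1] - query_start_loc[:num_reqs]
    # Safe because idx_mapping entries are unique within a batch.
    num_computed_tokens[idx_mapping] += query_lens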

View File

@@ -39,8 +39,10 @@ from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
 from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
-    combine_last_token_ids,
-    prepare_inputs,
+    combine_sampled_and_draft_tokens,
+    prepare_pos_seq_lens,
+    prepare_prefill_inputs,
+    update_num_computed_tokens,
 )
 from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
 from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
@@ -179,6 +181,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.vllm_config,
             self.device,
         )
+        # TODO(woosuk): Support other backends.
+        if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
+            raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
         self.kv_caches: list[torch.Tensor] = []
         init_kv_cache(
@@ -196,8 +201,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         slot_mappings = self.block_tables.get_dummy_slot_mappings(
             input_batch.num_tokens
         )
-        num_computed_tokens_cpu = torch.zeros(
-            input_batch.num_reqs, dtype=torch.int32, device="cpu"
+        num_computed_tokens = torch.zeros(
+            input_batch.num_reqs, dtype=torch.int32, device=self.device
         )
         attn_metadata = build_attn_metadata(
             attn_metadata_builders=self.attn_metadata_builders,
@@ -205,7 +210,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens=input_batch.num_tokens,
             query_start_loc=self.input_buffers.query_start_loc,
             seq_lens=self.input_buffers.seq_lens,
-            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            seq_lens_np=input_batch.seq_lens_np,
+            num_computed_tokens_cpu=num_computed_tokens,
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
@@ -368,6 +374,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cu_num_new_blocks[i].append(x + len(block_ids))
                 new_block_ids[i].extend(block_ids)
                 overwrite.append(True)
+        # Update the GPU tensors for request states.
+        if scheduler_output.scheduled_new_reqs:
+            self.req_states.prefill_len.copy_to_gpu()
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -421,46 +430,60 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
         block_tables = self.block_tables.gather_block_tables(idx_mapping)
-        prepare_inputs(
+        # Copy prefill tokens from CPU to GPU and get query_start_loc.
+        prepare_prefill_inputs(
             idx_mapping_np,
-            self.req_states.prefill_token_ids,
-            self.req_states.num_computed_tokens,
             num_scheduled_tokens,
-            self.input_buffers.input_ids,
-            self.input_buffers.positions,
-            self.input_buffers.query_start_loc,
-            self.input_buffers.seq_lens,
             num_tokens,
+            self.req_states.prefill_token_ids,
+            self.req_states.num_computed_prefill_tokens,
+            self.req_states.prefill_len.np,
+            self.input_buffers.input_ids,
+            self.input_buffers.query_start_loc,
         )
         query_start_loc = self.input_buffers.query_start_loc
         query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
         query_start_loc_np = query_start_loc.np[: num_reqs + 1]
-        seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
-        seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
-        # Some input token ids are directly read from the last sampled tokens.
-        combine_last_token_ids(
+        # Prepare positions and seq_lens.
+        prepare_pos_seq_lens(
+            idx_mapping,
+            query_start_loc_gpu,
+            self.req_states.num_computed_tokens,
+            self.input_buffers.positions,
+            self.input_buffers.seq_lens,
+        )
+        seq_lens = self.input_buffers.seq_lens[:num_reqs]
+        # Some input token ids are directly read from the last sampled tokens
+        # and draft tokens.
+        combine_sampled_and_draft_tokens(
             self.input_buffers.input_ids.gpu,
             idx_mapping,
             self.req_states.last_sampled_tokens,
             query_start_loc_gpu,
-            seq_lens_gpu,
-            self.req_states.prefill_len.copy_to_gpu(),
+            seq_lens,
+            self.req_states.prefill_len.gpu,
         )
         # Compute slot mappings: [num_kv_cache_groups, num_tokens]
         slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
+            query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
         )
-        num_computed_tokens_cpu = torch.from_numpy(
-            self.req_states.num_computed_tokens[idx_mapping_np]
-        )
         # Logits indices to sample next token from.
         logits_indices = query_start_loc_gpu[1:] - 1
+        # Get num_computed_tokens.
+        # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
+        # num_computed_tokens_cpu. This works for most cases.
+        num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping]
+        # HACK(woosuk): Only GPU has the exact seq_lens because at this point
+        # CPU does not know how many draft tokens are accepted/rejected in the
+        # previous step. Therefore, we use max_model_len to be safe.
+        # NOTE(woosuk): This only works for FA3 backend.
+        seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
         # Layer name -> attention metadata.
         attn_metadata = build_attn_metadata(
             attn_metadata_builders=self.attn_metadata_builders,
@@ -468,14 +491,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens=num_tokens,
             query_start_loc=self.input_buffers.query_start_loc,
             seq_lens=self.input_buffers.seq_lens,
-            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            seq_lens_np=seq_lens_np,
+            num_computed_tokens_cpu=num_computed_tokens,
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
         )
         input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
-        positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
+        positions = self.input_buffers.positions[:num_tokens_after_padding]
         return InputBatch(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -486,7 +510,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens_after_padding=num_tokens_after_padding,
             query_start_loc=query_start_loc_gpu,
             query_start_loc_np=query_start_loc_np,
-            seq_lens=seq_lens_gpu,
+            seq_lens=seq_lens,
             seq_lens_np=seq_lens_np,
             input_ids=input_ids,
             positions=positions,
@@ -500,11 +524,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         input_batch: InputBatch,
         sampling_metadata: SamplingMetadata,
         grammar_output: GrammarOutput | None,
-    ) -> SamplerOutput:
+    ) -> tuple[SamplerOutput, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
             # Apply grammar bitmask to the logits in-place.
+            # TODO(woosuk): Make compatible with spec decoding.
             with async_barrier(self.structured_outputs_event):
                 apply_grammar_bitmask(
                     logits,
@@ -513,8 +538,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     grammar_output.grammar_bitmask,
                     self.input_buffers,
                 )
         sampler_output = self.sampler(logits, sampling_metadata)
-        return sampler_output
+        # Get the number of sampled tokens.
+        # 0 if chunked-prefilling, 1 if not.
+        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
+        is_chunked_prefilling = input_batch.seq_lens < prefill_len
+        num_sampled = (~is_chunked_prefilling).int()
+        return sampler_output, num_sampled

     def compute_prompt_logprobs(
         self,
@@ -527,11 +558,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # No request asks for prompt logprobs.
             return {}
-        num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
         prompt_lens = self.req_states.prompt_len[idx_mapping_np]
         # NOTE(woosuk): -1 because the last prompt token's hidden state is not
         # needed for prompt logprobs.
-        includes_prompt = num_computed_tokens < prompt_lens - 1
+        computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
+        includes_prompt = computed_prefill < prompt_lens - 1
         # NOTE(woosuk): If the request was resumed after preemption, its prompt
         # logprobs must have been computed before preemption. Skip.
         resumed_after_prompt = (
@@ -550,8 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             token_ids[n - 1] = 0
         # Handle chunked prompts.
-        seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
-        is_prompt_chunked = seq_lens < prompt_lens
+        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
+        is_prompt_chunked = pos_after_step < prompt_lens
         prefill_token_ids = self.req_states.prefill_token_ids
         query_start_loc = self.input_buffers.query_start_loc.np
         for i, req_id in enumerate(input_batch.req_ids):
@@ -561,7 +592,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 continue
             # The prompt is chunked. Get the next prompt token.
             req_idx = input_batch.idx_mapping_np[i]
-            next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
+            next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
            idx = int(query_start_loc[i + 1] - 1)
             # Set the next prompt token.
             # NOTE(woosuk): This triggers a GPU operation.
@@ -617,48 +648,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def postprocess(
         self,
-        sampler_output: SamplerOutput,
-        prompt_logprobs_dict: dict[str, LogprobsTensors],
         input_batch: InputBatch,
-    ) -> AsyncOutput | ModelRunnerOutput:
-        # Store the last sampled token ids.
-        self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
-            sampler_output.sampled_token_ids
-        )
-        # Get the number of sampled tokens.
-        # 0 if chunked-prefilling, 1 if not.
+        sampled_tokens: torch.Tensor,
+        num_sampled: torch.Tensor,
+    ) -> None:
+        # Update the number of computed tokens.
+        update_num_computed_tokens(
+            input_batch.idx_mapping,
+            self.req_states.num_computed_tokens,
+            input_batch.query_start_loc,
+        )
         idx_mapping_np = input_batch.idx_mapping_np
-        is_chunked_prefilling = (
-            input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
-        )
-        num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
-        # Increment the number of tokens.
-        self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
-        # Increment the number of computed tokens.
-        self.req_states.num_computed_tokens[idx_mapping_np] += (
-            input_batch.num_scheduled_tokens
-        )
-        model_runner_output = ModelRunnerOutput(
-            req_ids=input_batch.req_ids,
-            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
-            sampled_token_ids=None,  # type: ignore
-            logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
-            pooler_output=[],
-            kv_connector_output=None,
-            num_nans_in_logits=None,
-        )
-        async_output = AsyncOutput(
-            model_runner_output=model_runner_output,
-            sampler_output=sampler_output,
-            num_sampled_tokens=num_sampled_tokens,
-            copy_stream=self.output_copy_stream,
-            copy_event=self.output_copy_event,
-        )
-        if self.use_async_scheduling:
-            return async_output
-        return async_output.get_output()
+        computed_prefill = self.req_states.num_computed_prefill_tokens
+        # TODO(woosuk): Simplify this.
+        computed_prefill[idx_mapping_np] = np.minimum(
+            computed_prefill[idx_mapping_np] + input_batch.num_scheduled_tokens,
+            self.req_states.prefill_len.np[idx_mapping_np],
+        )
+        # Store the last sampled token ids.
+        last_sampled = sampled_tokens
+        self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled

     def get_cudagraph_and_dp_padding(
         self,
@@ -782,6 +792,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
         else:
             # Run PyTorch model in eager mode.
+            # TODO(woosuk): Support piecewise CUDA graph.
             with set_forward_context(
                 input_batch.attn_metadata,
                 self.vllm_config,
@@ -807,13 +818,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.execute_model_state = None  # type: ignore
         assert sampling_metadata is not None
-        sampler_output = self.sample(
+        sampler_output, num_sampled_tokens = self.sample(
             hidden_states, input_batch, sampling_metadata, grammar_output
         )
         prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
-        output = self.postprocess(
-            sampler_output,
-            prompt_logprobs_dict,
-            input_batch,
+        # Prepare the model runner output.
+        model_runner_output = ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            # NOTE(woosuk): req_id_to_index is unused in this model runner.
+            # Only for compatibility with the existing model runner and scheduler.
+            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
+            sampled_token_ids=None,  # type: ignore
+            logprobs=None,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
+            pooler_output=[],
+            kv_connector_output=None,
+            num_nans_in_logits=None,
        )
-        return output
+        async_output = AsyncOutput(
+            model_runner_output=model_runner_output,
+            sampler_output=sampler_output,
+            num_sampled_tokens=num_sampled_tokens,
+            copy_stream=self.output_copy_stream,
+            copy_event=self.output_copy_event,
+        )
+        # Postprocess results and update request states.
+        # NOTE: This is intentionally done after creating the AsyncOutput,
+        # ensuring that `copy_event` is recorded before calling postprocess.
+        # This sequencing may slightly reduce latency as async D2H copy does not
+        # need to wait for the postprocess to finish.
+        self.postprocess(
+            input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
+        )
+        if self.use_async_scheduling:
+            return async_output
+        return async_output.get_output()
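
The new return value of sample() records, per request, whether a token was actually sampled in this step: during a chunked prefill the sequence length is still below the prompt (prefill) length, so nothing is sampled. A toy illustration with made-up numbers:

import torch

seq_lens = torch.tensor([128, 512, 512])     # lengths after this step
prefill_len = torch.tensor([512, 512, 256])  # full prompt lengths

is_chunked_prefilling = seq_lens < prefill_len
num_sampled = (~is_chunked_prefilling).int()
print(num_sampled)  # tensor([0, 1, 1], dtype=torch.int32)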

View File

@@ -85,8 +85,12 @@ class RequestState:
             dtype=np.int32,
         )
         self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
-        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
-        self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+        # Number of computed tokens.
+        self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.num_computed_tokens = torch.zeros(
+            self.max_num_reqs, dtype=torch.int32, device=device
+        )
         # Last sampled tokens.
         self.last_sampled_tokens = torch.zeros(
@@ -145,7 +149,10 @@ class RequestState:
         )
         self.prefill_len.np[req_idx] = prefill_len
         self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
-        self.num_tokens[req_idx] = prefill_len
+        self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
+        # FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
+        # Optimize this.
         self.num_computed_tokens[req_idx] = num_computed_tokens
         if lora_request is not None:
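
After this change the CPU only tracks prompt progress (num_computed_prefill_tokens, a NumPy array), while the authoritative per-request token count lives on the GPU (num_computed_tokens), where the Triton kernels update it each step. A compact sketch of the two counters (sizes and device are illustrative):

import numpy as np
import torch

max_num_reqs = 4
# Host side: how far into the prompt each request has progressed.
num_computed_prefill_tokens = np.zeros(max_num_reqs, dtype=np.int32)
# Device side: total computed tokens, updated by update_num_computed_tokens.
num_computed_tokens = torch.zeros(max_num_reqs, dtype=torch.int32, device="cuda")

# Adding a request that already has 16 computed tokens (e.g. a prefix-cache hit):
req_idx = 0
num_computed_prefill_tokens[req_idx] = 16
num_computed_tokens[req_idx] = 16  # note: a small host-to-device copy per new request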