[Model Runner V2] Remove async barrier (#32083)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-11 20:24:30 -08:00
committed by GitHub
parent 19504ac07f
commit 025a32f9ed
13 changed files with 589 additions and 461 deletions

View File

@@ -6,9 +6,8 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class BlockTables:
@@ -26,19 +25,16 @@ class BlockTables:
self.max_model_len = max_model_len
self.device = device
if not is_uva_available():
raise RuntimeError("UVA is not available")
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[UvaBuffer] = []
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = UvaBuffer(
self.max_num_reqs,
max_num_blocks,
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
device=device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(
@@ -53,9 +49,8 @@ class BlockTables:
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = UvaBuffer(
self.num_kv_cache_groups,
self.max_num_reqs,
self.num_blocks = UvaBackedTensor(
(self.num_kv_cache_groups, self.max_num_reqs),
dtype=torch.int32,
)
@@ -75,13 +70,11 @@ class BlockTables:
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor(
return torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.uint64,
device="cpu",
pin_memory=True,
device=self.device,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
@@ -90,19 +83,17 @@ class BlockTables:
overwrite: bool,
) -> None:
for i in range(self.num_kv_cache_groups):
block_ids = new_block_ids[i]
num_new_blocks = len(block_ids)
if num_new_blocks == 0:
continue
# TODO(woosuk): Too many Numpy invocations. Optimize this.
start = self.num_blocks.np[i, req_index] if not overwrite else 0
end = start + num_new_blocks
if num_new_blocks == 1:
self.block_tables[i].np[req_index, start] = block_ids[0]
else:
self.block_tables[i].np[req_index, start:end] = block_ids
self.num_blocks.np[i, req_index] = end
block_ids = new_block_ids[i]
self.block_tables[i].stage_write(req_index, start, block_ids)
self.num_blocks.np[i, req_index] = start + len(block_ids)
def apply_staged_writes(self) -> None:
# TODO(woosuk): This can be inefficient since it launches one kernel per
# block table. Implement a kernel to handle all block tables at once.
for block_table in self.block_tables:
block_table.apply_write()
self.num_blocks.copy_to_uva()
def gather_block_tables(
self,
@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)
class UvaBuffer:
def __init__(self, *size, dtype: torch.dtype):
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)

View File

@@ -0,0 +1,218 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
class UvaBuffer:
    """A zero-initialized pinned CPU tensor with a CUDA-side UVA view.

    The same physical host memory is exposed three ways: as a torch CPU
    tensor (`cpu`), as a numpy alias (`np`), and as a CUDA view (`uva`)
    obtained through unified virtual addressing, so the GPU can read the
    host buffer without an explicit copy.
    """

    def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
        # UVA support is a hard requirement for this buffer type.
        if not is_uva_available():
            raise RuntimeError("UVA is not available")
        pinned = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
        self.cpu = pinned
        # Numpy view sharing the pinned allocation (no copy).
        self.np = pinned.numpy()
        # Device-visible view of the same pinned memory.
        self.uva = get_cuda_view_from_cpu_tensor(pinned)
class UvaBufferPool:
    """Round-robin pool of UVA staging buffers for CPU-to-GPU transfers.

    Keeping `max_concurrency` buffers and rotating through them lets a new
    host-side write proceed while an earlier asynchronous transfer from a
    different buffer may still be in flight.
    """

    def __init__(
        self,
        size: int | Sequence[int],
        dtype: torch.dtype,
        max_concurrency: int = 2,
    ):
        self.size = size
        self.dtype = dtype
        self.max_concurrency = max_concurrency
        # One pinned, GPU-visible staging buffer per concurrency slot.
        self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
        # Index of the buffer handed out by the most recent call.
        self._curr = 0

    def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
        """Copy `x` into the next staging buffer and return its UVA view.

        Only the leading `len(x)` rows of the buffer are written and
        returned; the host-to-host copy happens synchronously.
        """
        # Rotate first so back-to-back calls land in distinct buffers.
        self._curr = (self._curr + 1) % self.max_concurrency
        buf = self._uva_bufs[self._curr]
        num = len(x)
        # Pick the view matching the source type for the host-side copy.
        if isinstance(x, torch.Tensor):
            buf.cpu[:num] = x
        else:
            buf.np[:num] = x
        return buf.uva[:num]

    def copy_to_gpu(
        self,
        x: torch.Tensor | np.ndarray,
        out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Stage `x` in UVA memory, then materialize it on the GPU.

        When `out` is given, the data is copied into it asynchronously and
        `out` is returned; otherwise a fresh GPU tensor is allocated.
        """
        staged = self.copy_to_uva(x)
        if out is not None:
            # Device copy into the caller-provided tensor.
            return out.copy_(staged, non_blocking=True)
        # `clone` reads through the UVA view into new device memory.
        return staged.clone()
class UvaBackedTensor:
    """A pageable CPU tensor (source of truth) published through UVA buffers.

    Host code mutates `cpu`/`np` freely; calling `copy_to_uva()` snapshots
    the current contents into one of the pool's staging buffers and updates
    `gpu`, the device-visible view of the latest snapshot.
    """

    def __init__(
        self,
        size: int | Sequence[int],
        dtype: torch.dtype,
        max_concurrency: int = 2,
    ):
        self.dtype = dtype
        self.max_concurrency = max_concurrency
        # Authoritative host copy; deliberately NOT pinned — the pinned
        # staging memory lives in the pool below.
        self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
        self.np = self.cpu.numpy()
        # Rotating staging buffers so publishes can overlap in-flight reads.
        self.pool = UvaBufferPool(size, dtype, max_concurrency)
        self.gpu = self.pool.copy_to_uva(self.np)

    def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
        """Publish the first `n` rows (all rows when None) and return the view."""
        src = self.np if n is None else self.np[:n]
        # Host-to-host copy into pinned staging memory.
        self.gpu = self.pool.copy_to_uva(src)
        return self.gpu
class StagedWriteTensor:
    """A GPU (or UVA-backed) tensor updated via batched, staged row writes.

    Callers accumulate sparse row-segment writes on the host with
    `stage_write()` / `stage_write_elem()`, then flush them all in a single
    Triton kernel launch with `apply_write()`. This avoids one GPU copy per
    individual write.
    """

    def __init__(
        self,
        size: int | Sequence[int],
        dtype: torch.dtype,
        device: torch.device,
        max_concurrency: int = 2,
        uva_instead_of_gpu: bool = False,
    ):
        # Only integer payloads are supported by the write kernel.
        if dtype not in [torch.int32, torch.int64]:
            raise ValueError(
                f"Unsupported dtype {dtype}: should be either int32 or int64"
            )
        self.num_rows = size if isinstance(size, int) else size[0]
        self.dtype = dtype
        self.max_concurrency = max_concurrency
        if not uva_instead_of_gpu:
            # Create a GPU tensor (default)
            self.gpu = torch.zeros(size, dtype=dtype, device=device)
        else:
            # For a large but not-frequently-accessed tensor, we can use UVA instead of
            # GPU to save GPU memory
            self._uva_buf = UvaBuffer(size, dtype)
            self.gpu = self._uva_buf.uva
        # Host-side staging: one entry per staged write in indices/starts/
        # cu_lens; contents is the flat concatenation of all payloads.
        self._staged_write_indices: list[int] = []
        self._staged_write_starts: list[int] = []
        self._staged_write_contents: list[int] = []
        self._staged_write_cu_lens: list[int] = []
        # UVA staging pools used to ship the write metadata to the GPU.
        # NOTE(review): these pools are sized `num_rows`, which assumes at
        # most `num_rows` staged writes between flushes — confirm callers
        # stage no more than one write per row per step.
        self.write_indices = UvaBufferPool(
            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
        )
        self.write_starts = UvaBufferPool(
            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
        )
        # Payload buffer; grown on demand in apply_write() (power-of-2 sizes).
        init_size = next_power_of_2(self.num_rows)
        self.write_contents = UvaBufferPool(
            init_size, dtype=dtype, max_concurrency=max_concurrency
        )
        self.write_cu_lens = UvaBufferPool(
            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
        )

    def stage_write(self, index: int, start: int, x: list[int]) -> None:
        """Stage writing `x` into row `index` starting at column `start`."""
        assert index >= 0
        assert start >= 0
        # Empty payloads are dropped so the kernel never sees zero-length
        # segments.
        if not x:
            return
        self._staged_write_indices.append(index)
        self._staged_write_starts.append(start)
        self._staged_write_contents.extend(x)
        # Cumulative length marks where this write's payload ends.
        self._staged_write_cu_lens.append(len(self._staged_write_contents))

    def stage_write_elem(self, index: int, x: int) -> None:
        """Stage writing the single value `x` at row `index`, column 0."""
        assert index >= 0
        self._staged_write_indices.append(index)
        self._staged_write_starts.append(0)
        self._staged_write_contents.append(x)
        self._staged_write_cu_lens.append(len(self._staged_write_contents))

    def apply_write(self) -> None:
        """Flush all staged writes to `self.gpu` in one kernel launch."""
        n = len(self._staged_write_indices)
        if n == 0:
            return
        # Ship per-write metadata to device-visible staging buffers.
        indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
        starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
        cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
        # Special handling for write_contents
        diff_len = len(self._staged_write_contents)
        assert isinstance(self.write_contents.size, int)
        if diff_len > self.write_contents.size:
            # Re-allocate a larger buffer for the write_contents
            new_size = next_power_of_2(diff_len)
            self.write_contents = UvaBufferPool(
                new_size, dtype=self.dtype, max_concurrency=self.max_concurrency
            )
            # NOTE(woosuk): Since the previous write_contents buffer is released,
            # we perform a synchronization here to ensure that all data transfers
            # involving the old buffer have finished before allocating a new one.
            # This prevents potential race conditions. The slight overhead is
            # negligible because the reallocations are infrequent in practice.
            torch.cuda.synchronize()
        contents_uva = self.write_contents.copy_to_uva(self._staged_write_contents)
        # Write diffs to the GPU buffer
        _apply_write_kernel[(n,)](
            self.gpu,
            self.gpu.stride(0),
            indices_uva,
            starts_uva,
            contents_uva,
            cu_lens_uva,
            BLOCK_SIZE=1024,
        )
        # Clear the staged writes
        self.clear_staged_writes()

    def clear_staged_writes(self) -> None:
        """Drop all staged writes without applying them."""
        self._staged_write_indices.clear()
        self._staged_write_starts.clear()
        self._staged_write_contents.clear()
        self._staged_write_cu_lens.clear()
@triton.jit
def _apply_write_kernel(
    output_ptr,        # [num_rows, row_len] destination tensor
    output_stride,     # row stride of the destination
    write_indices_ptr,  # [n] destination row index per staged write
    write_starts_ptr,   # [n] destination column offset per staged write
    write_contents_ptr,  # flat concatenation of all write payloads
    write_cu_lens_ptr,   # [n] cumulative payload lengths (exclusive start via pid-1)
    BLOCK_SIZE: tl.constexpr,
):
    # One program per staged write: program id selects the write record.
    pid = tl.program_id(0)
    row_idx = tl.load(write_indices_ptr + pid)
    start_idx = tl.load(write_starts_ptr + pid)
    # Payload for write `pid` spans [cu_start, cu_end) in the flat contents.
    cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
    cu_end = tl.load(write_cu_lens_ptr + pid)
    content_len = cu_end - cu_start
    # Copy the payload in BLOCK_SIZE chunks into the target row segment.
    for i in range(0, content_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < content_len
        content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
        tl.store(
            output_ptr + row_idx * output_stride + start_idx + block, content, mask=mask
        )

View File

@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc = input_buffers.query_start_loc
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
query_start_loc.np[num_reqs:] = num_tokens
query_start_loc.copy_to_gpu()
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
@@ -245,8 +248,8 @@ def prepare_inputs_to_capture(
attn_metadata_builders=attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,

View File

@@ -8,8 +8,6 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
@@ -21,30 +19,17 @@ class InputBuffers:
vocab_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.query_start_loc = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer(
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
@dataclass
@@ -56,6 +41,8 @@ class InputBatch:
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping: torch.Tensor
# [num_reqs]
# batch_idx -> num_scheduled_tokens
@@ -83,6 +70,7 @@ class InputBatch:
logits_indices: torch.Tensor
# [num_reqs + 1]
cu_num_logits: torch.Tensor
cu_num_logits_np: np.ndarray
@classmethod
def make_dummy(
@@ -96,33 +84,41 @@ class InputBatch:
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
# Pad for full CUDA graph mode.
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]
query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
input_buffers.query_start_loc[0] = 0
torch.cumsum(
seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
)
# Pad for full CUDA graph mode.
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
@@ -135,6 +131,7 @@ class InputBatch:
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
)
@@ -473,3 +470,38 @@ def post_update(
query_start_loc,
num_warps=1,
)
@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,           # [num_reqs] batch_idx -> req_state_idx
    expanded_idx_mapping_ptr,  # [total_num_logits] output, one entry per logit
    cu_num_logits_ptr,         # [num_reqs + 1] cumulative logit counts per request
    BLOCK_SIZE: tl.constexpr,  # must be >= max logits of any single request
):
    # One program per request: replicate its state index over all of its
    # logit positions.
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    block = tl.arange(0, BLOCK_SIZE)
    mask = block < num_tokens
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
def expand_idx_mapping(
    idx_mapping: torch.Tensor,
    total_num_logits: int,
    cu_num_logits: torch.Tensor,
    max_expand_len: int,
) -> torch.Tensor:
    """Expand a per-request index mapping to a per-logit mapping.

    Each request's state index is repeated once per logit it produces
    (given by consecutive differences of `cu_num_logits`), so the result
    has `total_num_logits` entries. `max_expand_len` bounds the number of
    logits of any single request and fixes the kernel block size.
    """
    num_reqs = idx_mapping.shape[0]
    # Allocated on the same device/dtype as `idx_mapping`; filled entirely
    # by the kernel below.
    expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
    _expand_idx_mapping_kernel[(num_reqs,)](
        idx_mapping,
        expanded_idx_mapping,
        cu_num_logits,
        BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
    )
    return expanded_idx_mapping

View File

@@ -15,7 +15,6 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -24,7 +23,7 @@ from vllm.v1.outputs import (
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
get_kv_cache_spec,
@@ -32,6 +31,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import (
get_batch_metadata_across_dp,
@@ -41,22 +41,20 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_sampled_and_draft_tokens,
expand_idx_mapping,
get_num_sampled_and_rejected,
post_update,
prepare_pos_seq_lens,
prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
from vllm.v1.worker.gpu.sample.metadata import (
SamplingMetadata,
expand_sampling_metadata,
)
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -81,7 +79,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.observability_config = vllm_config.observability_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = self.dtype
if self.cache_config.cache_dtype != "auto":
@@ -123,7 +120,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
@@ -132,12 +128,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.vocab_size,
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
)
# Buffers for CPU-to-GPU copies.
self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32)
self.tmp_cu_num_logits = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
self.tmp_query_start_loc = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
@@ -228,16 +233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
@@ -396,8 +398,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
if scheduler_output.scheduled_new_reqs:
self.req_states.prefill_len.copy_to_gpu()
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -409,6 +409,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_index, req_new_block_ids, overwrite=False
)
self.req_states.apply_staged_writes()
self.block_tables.apply_staged_writes()
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
@@ -431,19 +434,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
idx_mapping = self.tmp_idx_mapping.copy_to_gpu(idx_mapping_np)
# Get the number of draft tokens for each request.
if not scheduler_output.scheduled_spec_decode_tokens:
# No draft token scheduled (common case).
total_num_draft_tokens = 0
total_num_logits = num_reqs
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
cu_num_logits = torch.arange(
num_reqs + 1, device=self.device, dtype=torch.int32
)
expanded_idx_mapping = idx_mapping
else:
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
num_draft_tokens = np.array(
@@ -456,44 +459,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
np.cumsum(
num_draft_tokens + 1,
out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
num_logits = num_draft_tokens + 1
cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
cu_num_logits_np[0] = 0
np.cumsum(num_logits, out=cu_num_logits_np[1:])
cu_num_logits = self.tmp_cu_num_logits.copy_to_gpu(cu_num_logits_np)
expanded_idx_mapping = expand_idx_mapping(
idx_mapping,
total_num_logits,
cu_num_logits,
max_expand_len=self.num_speculative_steps + 1,
)
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc.
np.cumsum(
num_scheduled_tokens,
out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
)
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
self.input_buffers.query_start_loc.copy_to_gpu()
query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc_np[num_reqs + 1 :] = num_tokens
self.tmp_query_start_loc.copy_to_gpu(
query_start_loc_np,
out=self.input_buffers.query_start_loc,
)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
# Get prefill tokens.
prepare_prefill_inputs(
self.input_buffers.input_ids,
self.req_states.next_prefill_tokens,
idx_mapping,
query_start_loc_gpu,
query_start_loc,
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens,
self.req_states.num_computed_tokens.gpu,
)
# Prepare positions and seq_lens.
prepare_pos_seq_lens(
idx_mapping,
query_start_loc_gpu,
self.req_states.num_computed_tokens,
query_start_loc,
self.req_states.num_computed_tokens.gpu,
self.input_buffers.positions,
self.input_buffers.seq_lens,
)
@@ -505,7 +517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_buffers.input_ids,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
query_start_loc,
seq_lens,
self.req_states.prefill_len.gpu,
self.req_states.draft_tokens,
@@ -515,7 +527,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
query_start_loc, self.input_buffers.positions[:num_tokens]
)
# Layer name -> attention metadata.
@@ -523,7 +535,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
@@ -539,11 +551,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
num_draft_tokens=total_num_draft_tokens,
query_start_loc=query_start_loc_gpu,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
input_ids=input_ids,
@@ -551,6 +564,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata=attn_metadata,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
)
def sample(
@@ -564,16 +578,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
assert input_batch.num_draft_tokens == 0
with async_barrier(self.structured_outputs_event):
apply_grammar_bitmask(
logits,
input_batch.req_ids,
grammar_output.structured_output_request_ids,
grammar_output.grammar_bitmask,
self.input_buffers,
)
self.structured_outputs_worker.apply_grammar_bitmask(
logits,
input_batch,
grammar_output.structured_output_request_ids,
grammar_output.grammar_bitmask,
)
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(logits, sampling_metadata)
@@ -641,8 +651,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Handle chunked prompts.
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
is_prompt_chunked = pos_after_step < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids.np
query_start_loc = self.input_buffers.query_start_loc.np
prefill_token_ids = self.req_states.prefill_token_ids.gpu
query_start_loc_np = input_batch.query_start_loc_np
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
@@ -650,10 +660,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
idx = int(query_start_loc_np[i + 1] - 1)
# NOTE(woosuk): This triggers two GPU operations.
next_prompt_token = prefill_token_ids[req_idx, pos_after_step[i]]
token_ids[idx] = next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
@@ -669,8 +678,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc[i]
end_idx = query_start_loc[i + 1]
start_idx = query_start_loc_np[i]
end_idx = query_start_loc_np[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
@@ -714,7 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the number of computed tokens.
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens,
self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens,
self.req_states.output_bin_counts,
sampled_tokens,
@@ -825,61 +834,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert intermediate_tensors is None
if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
# No need to run the model.
with async_barrier(self.input_prep_event):
self.update_states(scheduler_output)
return EMPTY_MODEL_RUNNER_OUTPUT
self.update_states(scheduler_output)
return EMPTY_MODEL_RUNNER_OUTPUT
# NOTE: Call this before the async barrier so CPU all-reduce and
# GPU execution can overlap.
cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
self.get_cudagraph_and_dp_padding(scheduler_output)
)
with async_barrier(self.input_prep_event):
self.update_states(scheduler_output)
if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run.
return EMPTY_MODEL_RUNNER_OUTPUT
self.update_states(scheduler_output)
if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run.
return EMPTY_MODEL_RUNNER_OUTPUT
if not dummy_run:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs(
scheduler_output,
num_tokens_after_padding,
)
if not dummy_run:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs(
scheduler_output,
num_tokens_after_padding,
)
# NOTE(woosuk): Sampling metadata should be built under the async
# barrier to avoid race conditions.
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
)
if input_batch.num_draft_tokens > 0:
sampling_metadata = expand_sampling_metadata(
sampling_metadata,
input_batch.cu_num_logits,
max_expand_len=self.num_speculative_steps + 1,
)
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
else:
# No actual tokens to run. A dummy run for DP.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens_after_padding,
input_buffers=self.input_buffers,
device=self.device,
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self.prepare_dummy_attn_metadata(input_batch)
sampling_metadata = None
self._set_active_loras(*lora_inputs)
else:
# No actual tokens to run. A dummy run for DP.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens_after_padding,
input_buffers=self.input_buffers,
device=self.device,
)
self.prepare_dummy_attn_metadata(input_batch)
sampling_metadata = None
# Run model.
if cudagraph_mode == CUDAGraphMode.FULL:

View File

@@ -13,6 +13,7 @@ def _gumbel_sample_kernel(
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
@@ -20,22 +21,24 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
logits_ptr + batch_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
@@ -55,12 +58,13 @@ def _gumbel_sample_kernel(
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [num_reqs]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
@@ -88,6 +92,7 @@ def gumbel_sample(
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
seed,
pos,
temperature,

View File

@@ -4,20 +4,23 @@ from dataclasses import dataclass
import torch
from vllm.triton_utils import tl, triton
@dataclass
class SamplingMetadata:
idx_mapping: torch.Tensor
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
min_p: torch.Tensor | None
# For penalties
repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
seeds: torch.Tensor
pos: torch.Tensor
@@ -25,11 +28,6 @@ class SamplingMetadata:
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
# For penalties
idx_mapping: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
@classmethod
def make_dummy(
cls,
@@ -37,6 +35,8 @@ class SamplingMetadata:
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
@@ -51,18 +51,19 @@ class SamplingMetadata:
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
idx_mapping=idx_mapping,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -70,123 +71,9 @@ class SamplingMetadata:
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
temp_ptr,
expanded_temp_ptr,
top_p_ptr,
expanded_top_p_ptr,
top_k_ptr,
expanded_top_k_ptr,
min_p_ptr,
expanded_min_p_ptr,
rep_penalty_ptr,
expanded_rep_penalty_ptr,
freq_penalty_ptr,
expanded_freq_penalty_ptr,
pres_penalty_ptr,
expanded_pres_penalty_ptr,
seeds_ptr,
expanded_seeds_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_tokens
temp = tl.load(temp_ptr + req_idx)
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
if top_p_ptr is not None:
top_p = tl.load(top_p_ptr + req_idx)
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
if top_k_ptr is not None:
top_k = tl.load(top_k_ptr + req_idx)
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
if min_p_ptr is not None:
min_p = tl.load(min_p_ptr + req_idx)
tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask)
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
seed = tl.load(seeds_ptr + req_idx)
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
def expand_sampling_metadata(
sampling_metadata: SamplingMetadata,
cu_num_logits: torch.Tensor,
max_expand_len: int,
) -> SamplingMetadata:
total_num_logits = sampling_metadata.pos.shape[0]
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
expanded_temp = create_empty(sampling_metadata.temperature)
expanded_top_p = create_empty(sampling_metadata.top_p)
expanded_top_k = create_empty(sampling_metadata.top_k)
expanded_min_p = create_empty(sampling_metadata.min_p)
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
expanded_seeds = create_empty(sampling_metadata.seeds)
num_reqs = cu_num_logits.shape[0] - 1
_expand_sampling_metadata_kernel[(num_reqs,)](
sampling_metadata.temperature,
expanded_temp,
sampling_metadata.top_p,
expanded_top_p,
sampling_metadata.top_k,
expanded_top_k,
sampling_metadata.min_p,
expanded_min_p,
sampling_metadata.repetition_penalty,
expanded_repetition_penalty,
sampling_metadata.frequency_penalty,
expanded_frequency_penalty,
sampling_metadata.presence_penalty,
expanded_presence_penalty,
sampling_metadata.seeds,
expanded_seeds,
cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
)
return SamplingMetadata(
temperature=expanded_temp,
top_p=expanded_top_p,
top_k=expanded_top_k,
min_p=expanded_min_p,
seeds=expanded_seeds,
repetition_penalty=expanded_repetition_penalty,
frequency_penalty=expanded_frequency_penalty,
presence_penalty=expanded_presence_penalty,
pos=sampling_metadata.pos,
max_num_logprobs=sampling_metadata.max_num_logprobs,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping=sampling_metadata.idx_mapping,
prompt_bin_mask=sampling_metadata.prompt_bin_mask,
output_bin_counts=sampling_metadata.output_bin_counts,
)

View File

@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
min_p = tl.load(min_p_ptr + req_idx).to(tl.float32)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0:
return
@@ -39,12 +41,17 @@ def _min_p_kernel(
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None:
def apply_min_p(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
min_p: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
logits,
logits.stride(0),
idx_mapping,
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,

View File

@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
def _penalties_and_temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
temperature_ptr,
idx_mapping_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel(
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
temperature = tl.load(temperature_ptr + batch_idx)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
temperature = tl.load(temperature_ptr + req_state_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0
@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel(
logits = logits.to(tl.float32)
if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask,
@@ -92,11 +92,11 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.idx_mapping,
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.temperature,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts,

View File

@@ -71,7 +71,7 @@ class Sampler:
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p)
apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
@@ -79,6 +79,7 @@ class Sampler:
sampled = gumbel_sample(
logits,
sampling_metadata.idx_mapping,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch
import torch.nn as nn
@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
@@ -46,7 +44,6 @@ class EagleSpeculator:
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.pin_memory = is_pin_memory_available()
self.dtype = vllm_config.model_config.dtype
self.input_buffers = InputBuffers(
@@ -56,7 +53,6 @@ class EagleSpeculator:
vocab_size=self.vocab_size,
dtype=self.dtype,
device=device,
pin_memory=self.pin_memory,
)
self.hidden_states = torch.zeros(
self.max_num_tokens,
@@ -64,6 +60,11 @@ class EagleSpeculator:
dtype=self.dtype,
device=device,
)
self.idx_mapping = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=device,
)
self.temperature = torch.zeros(
self.max_num_reqs,
dtype=torch.float32,
@@ -140,7 +141,7 @@ class EagleSpeculator:
num_tokens_across_dp: torch.Tensor | None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
@@ -152,8 +153,9 @@ class EagleSpeculator:
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
self.temperature[:num_reqs],
self.seeds[:num_reqs],
self.idx_mapping[:num_reqs],
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
)
@@ -237,23 +239,27 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
temperature = self.temperature[:num_reqs]
seeds = self.seeds[:num_reqs]
pos = self.input_buffers.positions[:num_reqs]
idx_mapping = self.idx_mapping[:num_reqs]
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(sampling_metadata.temperature)
self.seeds.copy_(sampling_metadata.seeds)
# Gather the values and copy them to the pre-allocated buffers.
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits, temperature, seeds, pos + 1, apply_temperature=True
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
)
if self.num_speculative_steps == 1:
# Early exit.
@@ -273,11 +279,8 @@ class EagleSpeculator:
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, pos
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None:
@@ -286,8 +289,9 @@ class EagleSpeculator:
return self.draft_tokens[:num_reqs]
# Run eager mode.
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
@@ -295,7 +299,7 @@ class EagleSpeculator:
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
@@ -484,7 +488,7 @@ def prepare_eagle_decode(
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc.gpu,
input_buffers.query_start_loc,
input_buffers.seq_lens,
hidden_size,
max_model_len,

View File

@@ -8,10 +8,8 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount
@@ -29,7 +27,6 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -37,7 +34,6 @@ class RequestState:
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
@@ -47,16 +43,18 @@ class RequestState:
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32
# To save GPU memory, we use UVA instead of GPU for this tensor.
self.prefill_token_ids = StagedWriteTensor(
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros(
self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
@@ -84,14 +82,16 @@ class RequestState:
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.min_p = self._make_param(self.max_num_reqs, torch.float32)
self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.repetition_penalty = UvaBackedTensor(
self.max_num_reqs, dtype=torch.float32
)
self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
@@ -111,13 +111,7 @@ class RequestState:
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
self._penalties_reqs: list[int] = []
@property
def num_reqs(self) -> int:
@@ -144,12 +138,9 @@ class RequestState:
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids
self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
# Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
@@ -169,13 +160,7 @@ class RequestState:
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
bincount(
self.prefill_token_ids.gpu[req_idx],
prefill_len,
prompt_len,
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.append(req_idx)
if sampling_params.seed is not None:
seed = sampling_params.seed
@@ -193,6 +178,22 @@ class RequestState:
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def apply_staged_writes(self) -> None:
self.prefill_len.copy_to_uva()
self.prefill_token_ids.apply_write()
self.num_computed_tokens.apply_write()
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
self.prefill_token_ids.gpu[req_idx],
int(self.prefill_len.np[req_idx]),
int(self.prompt_len[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.clear()
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
@@ -208,30 +209,25 @@ class RequestState:
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping_np]
temperature = self.temperature.copy_np_to_gpu(temperature)
temperature = self.temperature.copy_to_uva()
top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_p = self.top_p.copy_to_uva()[idx_mapping] if not no_top_p else None
top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
top_k = self.top_k.copy_to_uva()[idx_mapping] if not no_top_k else None
min_p = self.min_p.np[idx_mapping_np]
no_min_p = np.all(min_p == 0.0)
min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None
min_p = self.min_p.copy_to_uva() if not no_min_p else None
rep_penalty = self.repetition_penalty.np[idx_mapping_np]
rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty)
freq_penalty = self.frequency_penalty.np[idx_mapping_np]
freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
pres_penalty = self.presence_penalty.np[idx_mapping_np]
pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
rep_penalty = self.repetition_penalty.copy_to_uva()
freq_penalty = self.frequency_penalty.copy_to_uva()
pres_penalty = self.presence_penalty.copy_to_uva()
seeds = self.seeds.np[idx_mapping_np]
seeds = self.seeds.copy_np_to_gpu(seeds)
seeds = self.seeds.copy_to_uva()
num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs))
@@ -239,6 +235,7 @@ class RequestState:
max_num_logprobs = None
return SamplingMetadata(
idx_mapping=idx_mapping,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -246,12 +243,11 @@ class RequestState:
repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty,
presence_penalty=pres_penalty,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
)
def make_lora_inputs(
@@ -272,42 +268,12 @@ class RequestState:
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
class UvaBuffer:
def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
assert is_uva_available()
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0

View File

@@ -4,38 +4,65 @@ import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.gpu.input_batch import InputBatch
def apply_grammar_bitmask(
logits: torch.Tensor,
req_ids: list[str],
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
input_buffers: InputBuffers,
) -> None:
input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
class StructuredOutputsWorker:
def __init__(
self,
max_num_logits: int,
vocab_size: int,
):
# NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
# to save a unnecessary CPU-to-CPU copy.
self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
self.grammar_bitmask = UvaBufferPool(
(max_num_logits, cdiv(vocab_size, 32)), torch.int32
)
batch_size = logits.shape[0]
grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
# logits -> bitmask mapping
mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
input_buffers.bitmask_indices.np[:batch_size] = mapping
input_buffers.bitmask_indices.copy_to_gpu(batch_size)
def apply_grammar_bitmask(
self,
logits: torch.Tensor,
input_batch: InputBatch,
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
) -> None:
if not grammar_req_ids:
return
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
input_buffers.grammar_bitmask.gpu,
input_buffers.grammar_bitmask.gpu.stride(0),
input_buffers.bitmask_indices.gpu,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Construct bitmask -> logits mapping
mapping: list[int] = []
req_ids = input_batch.req_ids
cu_num_logits = input_batch.cu_num_logits_np.tolist()
req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
for grammar_req_id in grammar_req_ids:
req_idx = req_id_to_idx[grammar_req_id]
logits_start_idx = cu_num_logits[req_idx]
logits_end_idx = cu_num_logits[req_idx + 1]
mapping.extend(range(logits_start_idx, logits_end_idx))
# Copy the mapping.
mapping_np = np.array(mapping, dtype=np.int32)
logits_indices = self.logits_indices.copy_to_uva(mapping_np)
# Copy the bitmask.
bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)
num_masks = bitmask.shape[0]
assert num_masks == len(mapping)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
logits_indices,
bitmask,
bitmask.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Adapted from
@@ -44,17 +71,14 @@ def apply_grammar_bitmask(
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
logits_indices_ptr,
bitmask_ptr,
bitmask_stride,
bitmask_indices_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
logits_idx = tl.program_id(0)
bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
if bitmask_idx == -1:
# No bitmask to apply.
return
bitmask_idx = tl.program_id(0)
logits_idx = tl.load(logits_indices_ptr + bitmask_idx)
# Load the bitmask.
block_id = tl.program_id(1)