[Model Runner V2] Remove async barrier (#32083)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -6,9 +6,8 @@ import torch

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class BlockTables:
@@ -26,19 +25,16 @@ class BlockTables:
        self.max_model_len = max_model_len
        self.device = device

        if not is_uva_available():
            raise RuntimeError("UVA is not available")

        self.num_kv_cache_groups = len(self.block_sizes)
        # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
        self.block_tables: list[UvaBuffer] = []
        self.block_tables: list[StagedWriteTensor] = []
        for i in range(self.num_kv_cache_groups):
            block_size = self.block_sizes[i]
            max_num_blocks = cdiv(self.max_model_len, block_size)
            block_table = UvaBuffer(
                self.max_num_reqs,
                max_num_blocks,
            block_table = StagedWriteTensor(
                (self.max_num_reqs, max_num_blocks),
                dtype=torch.int32,
                device=device,
            )
            self.block_tables.append(block_table)
        self.block_table_ptrs = self._make_ptr_tensor(
@@ -53,9 +49,8 @@ class BlockTables:
        self.block_sizes_tensor = torch.tensor(
            self.block_sizes, dtype=torch.int32, device=self.device
        )
        self.num_blocks = UvaBuffer(
            self.num_kv_cache_groups,
            self.max_num_reqs,
        self.num_blocks = UvaBackedTensor(
            (self.num_kv_cache_groups, self.max_num_reqs),
            dtype=torch.int32,
        )

@@ -75,13 +70,11 @@ class BlockTables:

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
        ptrs_tensor_cpu = torch.tensor(
        return torch.tensor(
            [t.data_ptr() for t in x],
            dtype=torch.uint64,
            device="cpu",
            pin_memory=True,
            device=self.device,
        )
        return ptrs_tensor_cpu.to(self.device, non_blocking=True)

    def append_block_ids(
        self,
@@ -90,19 +83,17 @@ class BlockTables:
        overwrite: bool,
    ) -> None:
        for i in range(self.num_kv_cache_groups):
            block_ids = new_block_ids[i]
            num_new_blocks = len(block_ids)
            if num_new_blocks == 0:
                continue

            # TODO(woosuk): Too many Numpy invocations. Optimize this.
            start = self.num_blocks.np[i, req_index] if not overwrite else 0
            end = start + num_new_blocks
            if num_new_blocks == 1:
                self.block_tables[i].np[req_index, start] = block_ids[0]
            else:
                self.block_tables[i].np[req_index, start:end] = block_ids
            self.num_blocks.np[i, req_index] = end
            block_ids = new_block_ids[i]
            self.block_tables[i].stage_write(req_index, start, block_ids)
            self.num_blocks.np[i, req_index] = start + len(block_ids)

    def apply_staged_writes(self) -> None:
        # TODO(woosuk): This can be inefficient since it launches one kernel per
        # block table. Implement a kernel to handle all block tables at once.
        for block_table in self.block_tables:
            block_table.apply_write()
        self.num_blocks.copy_to_uva()
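Note: a minimal sketch of how the staged-write flow above is meant to be used once per scheduler step. It assumes an already-constructed BlockTables instance (here `block_tables`) with two KV-cache groups; the concrete request indices and block IDs are illustrative only, not taken from this commit.

    # 1) CPU-side bookkeeping: stage new block IDs per request. This only appends
    #    to Python lists inside StagedWriteTensor and does not touch the GPU.
    block_tables.append_block_ids(req_index=0, new_block_ids=[[10, 11], [7]], overwrite=True)
    block_tables.append_block_ids(req_index=3, new_block_ids=[[42], []], overwrite=False)

    # 2) Flush once per step: one kernel launch per block table scatters all staged
    #    rows into the GPU tensors, and num_blocks is republished through UVA.
    block_tables.apply_staged_writes()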

    def gather_block_tables(
        self,
@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
    ptr = tl.load(ptr_to_ptr)
    ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
    return tl.multiple_of(ptr, 16)


class UvaBuffer:
    def __init__(self, *size, dtype: torch.dtype):
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
        self.np = self.cpu.numpy()
        self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
vllm/v1/worker/gpu/buffer_utils.py (new file, 218 lines)
@@ -0,0 +1,218 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence

import numpy as np
import torch

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor


class UvaBuffer:
    def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
        if not is_uva_available():
            raise RuntimeError("UVA is not available")
        self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
        self.np = self.cpu.numpy()
        self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
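Note: a short sketch of the zero-copy behavior this class relies on, assuming a CUDA device with UVA support. It uses only the fields defined above (`cpu`, `np`, `uva`); the example values are illustrative.

    import torch

    buf = UvaBuffer(8, dtype=torch.int32)

    # Writes go through the pinned CPU tensor (or its NumPy view) ...
    buf.np[:4] = [1, 2, 3, 4]

    # ... and are readable by GPU kernels through the CUDA view without an
    # explicit host-to-device copy, because both views alias the same pinned
    # host memory.
    gpu_sum = buf.uva[:4].sum()  # executes on the GPU
    assert int(gpu_sum) == 10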


class UvaBufferPool:
    def __init__(
        self,
        size: int | Sequence[int],
        dtype: torch.dtype,
        max_concurrency: int = 2,
    ):
        self.size = size
        self.dtype = dtype
        self.max_concurrency = max_concurrency

        # UVA buffers for concurrency
        self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
        # Current buffer index
        self._curr = 0

    def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
        # Round robin to the next buffer.
        self._curr = (self._curr + 1) % self.max_concurrency
        buf = self._uva_bufs[self._curr]
        # CPU-to-CPU copy
        dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
        n = len(x)
        dst[:n] = x
        return buf.uva[:n]

    def copy_to_gpu(
        self,
        x: torch.Tensor | np.ndarray,
        out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        uva = self.copy_to_uva(x)
        if out is None:
            # CPU-to-GPU copy
            return uva.clone()
        # CPU-to-GPU copy
        return out.copy_(uva, non_blocking=True)
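Note on why `max_concurrency=2`: a value copied into a pinned UVA buffer may still be read by an in-flight GPU kernel launched for the previous step, so the pool rotates through buffers instead of overwriting one in place. A rough usage sketch, assuming a CUDA device; variable names are illustrative.

    import numpy as np
    import torch

    pool = UvaBufferPool(16, dtype=torch.int32)  # two pinned buffers internally

    step0 = pool.copy_to_uva(np.arange(4, dtype=np.int32))  # alternates buffers
    step1 = pool.copy_to_uva(np.arange(8, dtype=np.int32))  # lands in the other buffer

    # step0 and step1 alias different pinned buffers, so a kernel still reading
    # step0 is not corrupted by the CPU write for step 1. More than two
    # overlapping steps would require a larger max_concurrency.
    out = torch.empty(8, dtype=torch.int32, device="cuda")
    pool.copy_to_gpu(np.arange(8, dtype=np.int32), out=out)  # async H2D copy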


class UvaBackedTensor:
    def __init__(
        self,
        size: int | Sequence[int],
        dtype: torch.dtype,
        max_concurrency: int = 2,
    ):
        self.dtype = dtype
        self.max_concurrency = max_concurrency

        # Source of truth
        self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
        self.np = self.cpu.numpy()

        # Buffers for concurrency
        self.pool = UvaBufferPool(size, dtype, max_concurrency)
        self.gpu = self.pool.copy_to_uva(self.np)

    def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
        # CPU-to-CPU copy
        self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
        return self.gpu


class StagedWriteTensor:
    def __init__(
        self,
        size: int | Sequence[int],
        dtype: torch.dtype,
        device: torch.device,
        max_concurrency: int = 2,
        uva_instead_of_gpu: bool = False,
    ):
        if dtype not in [torch.int32, torch.int64]:
            raise ValueError(
                f"Unsupported dtype {dtype}: should be either int32 or int64"
            )
        self.num_rows = size if isinstance(size, int) else size[0]
        self.dtype = dtype
        self.max_concurrency = max_concurrency

        if not uva_instead_of_gpu:
            # Create a GPU tensor (default)
            self.gpu = torch.zeros(size, dtype=dtype, device=device)
        else:
            # For a large but not-frequently-accessed tensor, we can use UVA instead of
            # GPU to save GPU memory
            self._uva_buf = UvaBuffer(size, dtype)
            self.gpu = self._uva_buf.uva

        self._staged_write_indices: list[int] = []
        self._staged_write_starts: list[int] = []
        self._staged_write_contents: list[int] = []
        self._staged_write_cu_lens: list[int] = []

        self.write_indices = UvaBufferPool(
            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
        )
        self.write_starts = UvaBufferPool(
            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
        )
        init_size = next_power_of_2(self.num_rows)
        self.write_contents = UvaBufferPool(
            init_size, dtype=dtype, max_concurrency=max_concurrency
        )
        self.write_cu_lens = UvaBufferPool(
            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
        )

    def stage_write(self, index: int, start: int, x: list[int]) -> None:
        assert index >= 0
        assert start >= 0
        if not x:
            return
        self._staged_write_indices.append(index)
        self._staged_write_starts.append(start)
        self._staged_write_contents.extend(x)
        self._staged_write_cu_lens.append(len(self._staged_write_contents))

    def stage_write_elem(self, index: int, x: int) -> None:
        assert index >= 0
        self._staged_write_indices.append(index)
        self._staged_write_starts.append(0)
        self._staged_write_contents.append(x)
        self._staged_write_cu_lens.append(len(self._staged_write_contents))

    def apply_write(self) -> None:
        n = len(self._staged_write_indices)
        if n == 0:
            return

        indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
        starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
        cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)

        # Special handling for write_contents
        diff_len = len(self._staged_write_contents)
        assert isinstance(self.write_contents.size, int)
        if diff_len > self.write_contents.size:
            # Re-allocate a larger buffer for the write_contents
            new_size = next_power_of_2(diff_len)
            self.write_contents = UvaBufferPool(
                new_size, dtype=self.dtype, max_concurrency=self.max_concurrency
            )
            # NOTE(woosuk): Since the previous write_contents buffer is released,
            # we perform a synchronization here to ensure that all data transfers
            # involving the old buffer have finished before allocating a new one.
            # This prevents potential race conditions. The slight overhead is
            # negligible because the reallocations are infrequent in practice.
            torch.cuda.synchronize()
        contents_uva = self.write_contents.copy_to_uva(self._staged_write_contents)

        # Write diffs to the GPU buffer
        _apply_write_kernel[(n,)](
            self.gpu,
            self.gpu.stride(0),
            indices_uva,
            starts_uva,
            contents_uva,
            cu_lens_uva,
            BLOCK_SIZE=1024,
        )
        # Clear the staged writes
        self.clear_staged_writes()

    def clear_staged_writes(self) -> None:
        self._staged_write_indices.clear()
        self._staged_write_starts.clear()
        self._staged_write_contents.clear()
        self._staged_write_cu_lens.clear()
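Note: the intended call pattern for the class above, using only the methods it defines. A minimal sketch, assuming a CUDA device; the shapes and values are illustrative.

    import torch

    t = StagedWriteTensor((4, 16), dtype=torch.int32, device=torch.device("cuda"))

    # Stage several row updates on the CPU side; nothing touches the GPU yet.
    t.stage_write(index=0, start=0, x=[10, 11, 12])  # row 0, columns 0..2
    t.stage_write(index=2, start=5, x=[99])          # row 2, column 5
    t.stage_write_elem(index=3, x=7)                 # row 3, column 0

    # One kernel launch scatters every staged segment into t.gpu, then the
    # staging lists are cleared for the next step.
    t.apply_write()
    assert int(t.gpu[2, 5]) == 99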


@triton.jit
def _apply_write_kernel(
    output_ptr,
    output_stride,
    write_indices_ptr,
    write_starts_ptr,
    write_contents_ptr,
    write_cu_lens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_idx = tl.load(write_indices_ptr + pid)
    start_idx = tl.load(write_starts_ptr + pid)

    cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
    cu_end = tl.load(write_cu_lens_ptr + pid)
    content_len = cu_end - cu_start

    for i in range(0, content_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < content_len
        content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
        tl.store(
            output_ptr + row_idx * output_stride + start_idx + block, content, mask=mask
        )
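Note: for readability, the scatter performed by `_apply_write_kernel` (one program per staged write) can be expressed in a few lines of NumPy. This is a reference sketch only, not code from this commit.

    import numpy as np

    def apply_write_reference(out, indices, starts, contents, cu_lens):
        # out[row, start : start + length] = that write's slice of the flattened
        # contents, mirroring one Triton program per staged write.
        prev = 0
        for row, start, end in zip(indices, starts, cu_lens):
            seg = contents[prev:end]
            out[row, start:start + len(seg)] = seg
            prev = end

    out = np.zeros((4, 8), dtype=np.int32)
    apply_write_reference(out, indices=[0, 2], starts=[0, 3],
                          contents=[10, 11, 12, 99], cu_lens=[3, 4])
    assert out[2, 3] == 99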
|
||||
@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
query_start_loc = input_buffers.query_start_loc
|
||||
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
|
||||
query_start_loc.np[num_reqs:] = num_tokens
|
||||
query_start_loc.copy_to_gpu()
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
query_start_loc_np[-1] = num_tokens
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
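Note: the capture-time layout above is easiest to see with small numbers; this is only a worked example of the arithmetic, with illustrative sizes.

    import numpy as np

    num_reqs, num_tokens = 4, 32
    num_tokens_per_req = num_tokens // num_reqs                 # 8
    qsl = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    qsl[-1] = num_tokens
    print(qsl)  # [ 0  8 16 24 32]: each captured request gets an equal token slice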
|
||||
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len.
|
||||
@@ -245,8 +248,8 @@ def prepare_inputs_to_capture(
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
|
||||
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
|
||||
@@ -8,8 +8,6 @@ import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
@@ -21,30 +19,17 @@ class InputBuffers:
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
|
||||
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
self.query_start_loc = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
|
||||
# Structured outputs.
|
||||
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.grammar_bitmask = self._make_buffer(
|
||||
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
|
||||
)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,6 +41,8 @@ class InputBatch:
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
# Identical to idx_mapping except for spec decoding.
|
||||
expanded_idx_mapping: torch.Tensor
|
||||
|
||||
# [num_reqs]
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
@@ -83,6 +70,7 @@ class InputBatch:
|
||||
logits_indices: torch.Tensor
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor
|
||||
cu_num_logits_np: np.ndarray
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
@@ -96,33 +84,41 @@ class InputBatch:
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
expanded_idx_mapping = idx_mapping
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
|
||||
input_buffers.query_start_loc.np[0] = 0
|
||||
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
|
||||
num_scheduled_tokens
|
||||
)
|
||||
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
|
||||
# seq_len equals to query_len
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
|
||||
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
|
||||
# Pad for full CUDA graph mode.
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
seq_lens = input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
|
||||
query_start_loc_np[0] = 0
|
||||
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
|
||||
input_buffers.query_start_loc[0] = 0
|
||||
torch.cumsum(
|
||||
seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
|
||||
)
|
||||
# Pad for full CUDA graph mode.
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
|
||||
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
|
||||
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
expanded_idx_mapping=expanded_idx_mapping,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
@@ -135,6 +131,7 @@ class InputBatch:
|
||||
attn_metadata=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
)
|
||||
|
||||
|
||||
@@ -473,3 +470,38 @@ def post_update(
|
||||
query_start_loc,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _expand_idx_mapping_kernel(
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
cu_num_logits_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_tokens
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
|
||||
|
||||
|
||||
def expand_idx_mapping(
|
||||
idx_mapping: torch.Tensor,
|
||||
total_num_logits: int,
|
||||
cu_num_logits: torch.Tensor,
|
||||
max_expand_len: int,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
|
||||
_expand_idx_mapping_kernel[(num_reqs,)](
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
cu_num_logits,
|
||||
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
|
||||
)
|
||||
return expanded_idx_mapping
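Note: functionally, the kernel above performs a per-request repeat of the index mapping. A torch reference of the same semantics (readability aid only; the Triton version avoids extra kernel launches):

    import torch

    def expand_idx_mapping_reference(idx_mapping: torch.Tensor,
                                     cu_num_logits: torch.Tensor) -> torch.Tensor:
        # Each request contributes (cu[i+1] - cu[i]) logits, all of which map
        # back to the same request-state index.
        num_logits_per_req = (cu_num_logits[1:] - cu_num_logits[:-1]).long()
        return torch.repeat_interleave(idx_mapping, num_logits_per_req)

    idx_mapping = torch.tensor([5, 2, 7], dtype=torch.int32)
    cu_num_logits = torch.tensor([0, 3, 4, 6], dtype=torch.int32)
    print(expand_idx_mapping_reference(idx_mapping, cu_num_logits))
    # tensor([5, 5, 5, 2, 7, 7], dtype=torch.int32)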
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@@ -24,7 +23,7 @@ from vllm.v1.outputs import (
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
get_kv_cache_spec,
|
||||
@@ -32,6 +31,7 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
init_kv_cache,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import (
|
||||
get_batch_metadata_across_dp,
|
||||
@@ -41,22 +41,20 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
combine_sampled_and_draft_tokens,
|
||||
expand_idx_mapping,
|
||||
get_num_sampled_and_rejected,
|
||||
post_update,
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.sample.metadata import (
|
||||
SamplingMetadata,
|
||||
expand_sampling_metadata,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
|
||||
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
@@ -81,7 +79,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self.kv_cache_dtype = self.dtype
|
||||
if self.cache_config.cache_dtype != "auto":
|
||||
@@ -123,7 +120,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
@@ -132,12 +128,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
vocab_size=self.vocab_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
|
||||
# Structured outputs worker.
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
|
||||
# Buffers for CPU-to-GPU copies.
|
||||
self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32)
|
||||
self.tmp_cu_num_logits = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
|
||||
self.tmp_query_start_loc = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
@@ -228,16 +233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=input_batch.num_reqs,
|
||||
num_tokens=input_batch.num_tokens,
|
||||
query_start_loc_gpu=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
query_start_loc_gpu=input_batch.query_start_loc,
|
||||
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
|
||||
seq_lens=input_batch.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
@@ -396,8 +398,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.block_tables.append_block_ids(
|
||||
req_index, new_req_data.block_ids, overwrite=True
|
||||
)
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.prefill_len.copy_to_gpu()
|
||||
|
||||
# Add new blocks for the existing requests.
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
@@ -409,6 +409,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
req_index, req_new_block_ids, overwrite=False
|
||||
)
|
||||
|
||||
self.req_states.apply_staged_writes()
|
||||
self.block_tables.apply_staged_writes()
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
@@ -431,19 +434,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
idx_mapping_list = [
|
||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
||||
]
|
||||
idx_mapping = self.input_buffers.idx_mapping
|
||||
idx_mapping.np[:num_reqs] = idx_mapping_list
|
||||
idx_mapping_np = idx_mapping.np[:num_reqs]
|
||||
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
|
||||
idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
|
||||
idx_mapping = self.tmp_idx_mapping.copy_to_gpu(idx_mapping_np)
|
||||
|
||||
# Get the number of draft tokens for each request.
|
||||
if not scheduler_output.scheduled_spec_decode_tokens:
|
||||
# No draft token scheduled (common case).
|
||||
total_num_draft_tokens = 0
|
||||
total_num_logits = num_reqs
|
||||
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
|
||||
cu_num_logits = torch.arange(
|
||||
num_reqs + 1, device=self.device, dtype=torch.int32
|
||||
)
|
||||
expanded_idx_mapping = idx_mapping
|
||||
else:
|
||||
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
num_draft_tokens = np.array(
|
||||
@@ -456,44 +459,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
total_num_draft_tokens = int(num_draft_tokens.sum())
|
||||
total_num_logits = num_reqs + total_num_draft_tokens
|
||||
|
||||
np.cumsum(
|
||||
num_draft_tokens + 1,
|
||||
out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
|
||||
num_logits = num_draft_tokens + 1
|
||||
cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
|
||||
cu_num_logits_np[0] = 0
|
||||
np.cumsum(num_logits, out=cu_num_logits_np[1:])
|
||||
cu_num_logits = self.tmp_cu_num_logits.copy_to_gpu(cu_num_logits_np)
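Note: a small worked example of the cu_num_logits construction above, with illustrative draft-token counts.

    import numpy as np

    num_draft_tokens = np.array([2, 0, 3], dtype=np.int32)  # per request
    num_logits = num_draft_tokens + 1                       # [3, 1, 4]
    cu_num_logits_np = np.empty(len(num_logits) + 1, dtype=np.int32)
    cu_num_logits_np[0] = 0
    np.cumsum(num_logits, out=cu_num_logits_np[1:])
    print(cu_num_logits_np)  # [0 3 4 8]: request i owns logits cu[i]:cu[i+1]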
|
||||
|
||||
expanded_idx_mapping = expand_idx_mapping(
|
||||
idx_mapping,
|
||||
total_num_logits,
|
||||
cu_num_logits,
|
||||
max_expand_len=self.num_speculative_steps + 1,
|
||||
)
|
||||
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
|
||||
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
# Get query_start_loc.
|
||||
np.cumsum(
|
||||
num_scheduled_tokens,
|
||||
out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
|
||||
)
|
||||
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
|
||||
query_start_loc_np[0] = 0
|
||||
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
|
||||
# Pad for full CUDA graph mode.
|
||||
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
|
||||
self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
|
||||
self.input_buffers.query_start_loc.copy_to_gpu()
|
||||
query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
|
||||
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
query_start_loc_np[num_reqs + 1 :] = num_tokens
|
||||
self.tmp_query_start_loc.copy_to_gpu(
|
||||
query_start_loc_np,
|
||||
out=self.input_buffers.query_start_loc,
|
||||
)
|
||||
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
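Note: the padding requirement mentioned above (query_start_loc must stay non-decreasing for full CUDA graph mode, e.g. for FA3) in a small worked example; sizes are illustrative.

    import numpy as np

    max_num_reqs, num_reqs = 6, 3
    num_scheduled_tokens = np.array([4, 1, 2], dtype=np.int32)
    num_tokens = int(num_scheduled_tokens.sum())              # 7

    qsl = np.empty(max_num_reqs + 1, dtype=np.int32)
    qsl[0] = 0
    np.cumsum(num_scheduled_tokens, out=qsl[1 : num_reqs + 1])
    qsl[num_reqs + 1 :] = num_tokens                          # pad with the total
    print(qsl)  # [0 4 5 7 7 7 7]: monotonically non-decreasing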
|
||||
|
||||
# Get prefill tokens.
|
||||
prepare_prefill_inputs(
|
||||
self.input_buffers.input_ids,
|
||||
self.req_states.next_prefill_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc_gpu,
|
||||
query_start_loc,
|
||||
self.req_states.prefill_token_ids.gpu,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
)
|
||||
|
||||
# Prepare positions and seq_lens.
|
||||
prepare_pos_seq_lens(
|
||||
idx_mapping,
|
||||
query_start_loc_gpu,
|
||||
self.req_states.num_computed_tokens,
|
||||
query_start_loc,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
self.input_buffers.positions,
|
||||
self.input_buffers.seq_lens,
|
||||
)
|
||||
@@ -505,7 +517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.input_buffers.input_ids,
|
||||
idx_mapping,
|
||||
self.req_states.last_sampled_tokens,
|
||||
query_start_loc_gpu,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.draft_tokens,
|
||||
@@ -515,7 +527,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
|
||||
query_start_loc, self.input_buffers.positions[:num_tokens]
|
||||
)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
@@ -523,7 +535,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc_gpu,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
@@ -539,11 +551,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
expanded_idx_mapping=expanded_idx_mapping,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens_after_padding,
|
||||
num_draft_tokens=total_num_draft_tokens,
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
@@ -551,6 +564,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
)
|
||||
|
||||
def sample(
|
||||
@@ -564,16 +578,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if grammar_output is not None:
|
||||
# Apply grammar bitmask to the logits in-place.
|
||||
# TODO(woosuk): Make compatible with spec decoding.
|
||||
assert input_batch.num_draft_tokens == 0
|
||||
with async_barrier(self.structured_outputs_event):
|
||||
apply_grammar_bitmask(
|
||||
logits,
|
||||
input_batch.req_ids,
|
||||
grammar_output.structured_output_request_ids,
|
||||
grammar_output.grammar_bitmask,
|
||||
self.input_buffers,
|
||||
)
|
||||
self.structured_outputs_worker.apply_grammar_bitmask(
|
||||
logits,
|
||||
input_batch,
|
||||
grammar_output.structured_output_request_ids,
|
||||
grammar_output.grammar_bitmask,
|
||||
)
|
||||
|
||||
# Sample tokens and compute logprobs (if needed).
|
||||
sampler_output = self.sampler(logits, sampling_metadata)
|
||||
@@ -641,8 +651,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Handle chunked prompts.
|
||||
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
|
||||
is_prompt_chunked = pos_after_step < prompt_lens
|
||||
prefill_token_ids = self.req_states.prefill_token_ids.np
|
||||
query_start_loc = self.input_buffers.query_start_loc.np
|
||||
prefill_token_ids = self.req_states.prefill_token_ids.gpu
|
||||
query_start_loc_np = input_batch.query_start_loc_np
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
if not needs_prompt_logprobs[i]:
|
||||
continue
|
||||
@@ -650,10 +660,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
continue
|
||||
# The prompt is chunked. Get the next prompt token.
|
||||
req_idx = input_batch.idx_mapping_np[i]
|
||||
next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
|
||||
idx = int(query_start_loc[i + 1] - 1)
|
||||
# Set the next prompt token.
|
||||
# NOTE(woosuk): This triggers a GPU operation.
|
||||
idx = int(query_start_loc_np[i + 1] - 1)
|
||||
# NOTE(woosuk): This triggers two GPU operations.
|
||||
next_prompt_token = prefill_token_ids[req_idx, pos_after_step[i]]
|
||||
token_ids[idx] = next_prompt_token
|
||||
|
||||
# NOTE(woosuk): We mask out logprobs for negative tokens.
|
||||
@@ -669,8 +678,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if not needs_prompt_logprobs[i]:
|
||||
continue
|
||||
|
||||
start_idx = query_start_loc[i]
|
||||
end_idx = query_start_loc[i + 1]
|
||||
start_idx = query_start_loc_np[i]
|
||||
end_idx = query_start_loc_np[i + 1]
|
||||
assert start_idx < end_idx, (
|
||||
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
|
||||
)
|
||||
@@ -714,7 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Update the number of computed tokens.
|
||||
post_update(
|
||||
input_batch.idx_mapping,
|
||||
self.req_states.num_computed_tokens,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
self.req_states.last_sampled_tokens,
|
||||
self.req_states.output_bin_counts,
|
||||
sampled_tokens,
|
||||
@@ -825,61 +834,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert intermediate_tensors is None
|
||||
if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
|
||||
# No need to run the model.
|
||||
with async_barrier(self.input_prep_event):
|
||||
self.update_states(scheduler_output)
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
self.update_states(scheduler_output)
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
# NOTE: Call this before the async barrier so CPU all-reduce and
|
||||
# GPU execution can overlap.
|
||||
cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
|
||||
self.get_cudagraph_and_dp_padding(scheduler_output)
|
||||
)
|
||||
with async_barrier(self.input_prep_event):
|
||||
self.update_states(scheduler_output)
|
||||
if num_tokens_after_padding == 0:
|
||||
# All DP ranks have zero tokens to run.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
self.update_states(scheduler_output)
|
||||
if num_tokens_after_padding == 0:
|
||||
# All DP ranks have zero tokens to run.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
if not dummy_run:
|
||||
# Common case.
|
||||
# Prepare all the inputs and copy to the input buffers.
|
||||
input_batch = self.prepare_inputs(
|
||||
scheduler_output,
|
||||
num_tokens_after_padding,
|
||||
)
|
||||
if not dummy_run:
|
||||
# Common case.
|
||||
# Prepare all the inputs and copy to the input buffers.
|
||||
input_batch = self.prepare_inputs(
|
||||
scheduler_output,
|
||||
num_tokens_after_padding,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Sampling metadata should be built under the async
|
||||
# barrier to avoid race conditions.
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
|
||||
)
|
||||
if input_batch.num_draft_tokens > 0:
|
||||
sampling_metadata = expand_sampling_metadata(
|
||||
sampling_metadata,
|
||||
input_batch.cu_num_logits,
|
||||
max_expand_len=self.num_speculative_steps + 1,
|
||||
)
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
|
||||
)
|
||||
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
lora_inputs = self.req_states.make_lora_inputs(
|
||||
input_batch.req_ids,
|
||||
input_batch.idx_mapping_np,
|
||||
input_batch.num_scheduled_tokens,
|
||||
)
|
||||
self._set_active_loras(*lora_inputs)
|
||||
else:
|
||||
# No actual tokens to run. A dummy run for DP.
|
||||
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
lora_inputs = self.req_states.make_lora_inputs(
|
||||
input_batch.req_ids,
|
||||
input_batch.idx_mapping_np,
|
||||
input_batch.num_scheduled_tokens,
|
||||
)
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
sampling_metadata = None
|
||||
self._set_active_loras(*lora_inputs)
|
||||
else:
|
||||
# No actual tokens to run. A dummy run for DP.
|
||||
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
sampling_metadata = None
|
||||
|
||||
# Run model.
|
||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||
|
||||
@@ -13,6 +13,7 @@ def _gumbel_sample_kernel(
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
@@ -20,22 +21,24 @@ def _gumbel_sample_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block,
|
||||
logits_ptr + batch_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
|
||||
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
pos = tl.load(pos_ptr + req_idx)
|
||||
seed = tl.load(seeds_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise.
|
||||
@@ -55,12 +58,13 @@ def _gumbel_sample_kernel(
|
||||
idx = tl.argmax(logits, axis=0)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
value = tl.max(logits, axis=0)
|
||||
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
||||
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
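Note: the kernel above is a Gumbel-max style sampler (argmax over logits plus Gumbel noise is equivalent to sampling from the softmax). A compact torch reference of the idea, ignoring this commit's per-request seeding and the idx_mapping indirection; it is a sketch, not the actual sampler.

    import torch

    def gumbel_sample_reference(logits: torch.Tensor,
                                temperature: torch.Tensor) -> torch.Tensor:
        # logits: [num_reqs, vocab_size]; temperature: [num_reqs], 0 means greedy.
        noise = -torch.log(-torch.log(torch.rand(logits.shape, device=logits.device)))
        scaled = torch.where(
            temperature[:, None] > 0,
            logits / temperature[:, None].clamp(min=1e-6) + noise,
            logits,
        )
        return scaled.argmax(dim=-1)  # [num_reqs] sampled token ids

    logits = torch.randn(2, 32000)
    temperature = torch.tensor([0.0, 0.7])
    print(gumbel_sample_reference(logits, temperature))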
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
idx_mapping: torch.Tensor, # [num_reqs]
|
||||
temperature: torch.Tensor, # [num_reqs]
|
||||
seed: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
@@ -88,6 +92,7 @@ def gumbel_sample(
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
|
||||
@@ -4,20 +4,23 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
idx_mapping: torch.Tensor
|
||||
|
||||
temperature: torch.Tensor
|
||||
|
||||
top_p: torch.Tensor | None
|
||||
top_k: torch.Tensor | None
|
||||
min_p: torch.Tensor | None
|
||||
|
||||
# For penalties
|
||||
repetition_penalty: torch.Tensor
|
||||
frequency_penalty: torch.Tensor
|
||||
presence_penalty: torch.Tensor
|
||||
prompt_bin_mask: torch.Tensor
|
||||
output_bin_counts: torch.Tensor
|
||||
|
||||
seeds: torch.Tensor
|
||||
pos: torch.Tensor
|
||||
@@ -25,11 +28,6 @@ class SamplingMetadata:
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: int | None
|
||||
|
||||
# For penalties
|
||||
idx_mapping: torch.Tensor
|
||||
prompt_bin_mask: torch.Tensor
|
||||
output_bin_counts: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
@@ -37,6 +35,8 @@ class SamplingMetadata:
|
||||
device: torch.device,
|
||||
) -> "SamplingMetadata":
|
||||
assert num_reqs > 0
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
|
||||
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
temperature[0] = 0.5
|
||||
# TODO(woosuk): Use top-p and top-k for dummy sampler.
|
||||
@@ -51,18 +51,19 @@ class SamplingMetadata:
|
||||
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
||||
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
max_num_logprobs = 20
|
||||
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
|
||||
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
|
||||
# specialization and re-compilation at runtime.
|
||||
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||
|
||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
max_num_logprobs = 20
|
||||
|
||||
return cls(
|
||||
idx_mapping=idx_mapping,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -70,123 +71,9 @@ class SamplingMetadata:
|
||||
repetition_penalty=repetition_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
prompt_bin_mask=prompt_bin_mask,
|
||||
output_bin_counts=output_bin_counts,
|
||||
seeds=seeds,
|
||||
pos=pos,
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
idx_mapping=idx_mapping,
|
||||
prompt_bin_mask=prompt_bin_mask,
|
||||
output_bin_counts=output_bin_counts,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
|
||||
@triton.jit
|
||||
def _expand_sampling_metadata_kernel(
|
||||
temp_ptr,
|
||||
expanded_temp_ptr,
|
||||
top_p_ptr,
|
||||
expanded_top_p_ptr,
|
||||
top_k_ptr,
|
||||
expanded_top_k_ptr,
|
||||
min_p_ptr,
|
||||
expanded_min_p_ptr,
|
||||
rep_penalty_ptr,
|
||||
expanded_rep_penalty_ptr,
|
||||
freq_penalty_ptr,
|
||||
expanded_freq_penalty_ptr,
|
||||
pres_penalty_ptr,
|
||||
expanded_pres_penalty_ptr,
|
||||
seeds_ptr,
|
||||
expanded_seeds_ptr,
|
||||
cu_num_logits_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_tokens
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
|
||||
|
||||
if top_p_ptr is not None:
|
||||
top_p = tl.load(top_p_ptr + req_idx)
|
||||
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
|
||||
|
||||
if top_k_ptr is not None:
|
||||
top_k = tl.load(top_k_ptr + req_idx)
|
||||
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
|
||||
|
||||
if min_p_ptr is not None:
|
||||
min_p = tl.load(min_p_ptr + req_idx)
|
||||
tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask)
|
||||
|
||||
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
|
||||
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
|
||||
|
||||
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
|
||||
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
|
||||
|
||||
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
|
||||
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
|
||||
|
||||
|
||||
def expand_sampling_metadata(
|
||||
sampling_metadata: SamplingMetadata,
|
||||
cu_num_logits: torch.Tensor,
|
||||
max_expand_len: int,
|
||||
) -> SamplingMetadata:
|
||||
total_num_logits = sampling_metadata.pos.shape[0]
|
||||
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
|
||||
expanded_temp = create_empty(sampling_metadata.temperature)
|
||||
expanded_top_p = create_empty(sampling_metadata.top_p)
|
||||
expanded_top_k = create_empty(sampling_metadata.top_k)
|
||||
expanded_min_p = create_empty(sampling_metadata.min_p)
|
||||
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
|
||||
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
|
||||
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
|
||||
expanded_seeds = create_empty(sampling_metadata.seeds)
|
||||
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
_expand_sampling_metadata_kernel[(num_reqs,)](
|
||||
sampling_metadata.temperature,
|
||||
expanded_temp,
|
||||
sampling_metadata.top_p,
|
||||
expanded_top_p,
|
||||
sampling_metadata.top_k,
|
||||
expanded_top_k,
|
||||
sampling_metadata.min_p,
|
||||
expanded_min_p,
|
||||
sampling_metadata.repetition_penalty,
|
||||
expanded_repetition_penalty,
|
||||
sampling_metadata.frequency_penalty,
|
||||
expanded_frequency_penalty,
|
||||
sampling_metadata.presence_penalty,
|
||||
expanded_presence_penalty,
|
||||
sampling_metadata.seeds,
|
||||
expanded_seeds,
|
||||
cu_num_logits,
|
||||
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
|
||||
)
|
||||
return SamplingMetadata(
|
||||
temperature=expanded_temp,
|
||||
top_p=expanded_top_p,
|
||||
top_k=expanded_top_k,
|
||||
min_p=expanded_min_p,
|
||||
seeds=expanded_seeds,
|
||||
repetition_penalty=expanded_repetition_penalty,
|
||||
frequency_penalty=expanded_frequency_penalty,
|
||||
presence_penalty=expanded_presence_penalty,
|
||||
pos=sampling_metadata.pos,
|
||||
max_num_logprobs=sampling_metadata.max_num_logprobs,
|
||||
# TODO(woosuk): Support penalties with spec decoding.
|
||||
idx_mapping=sampling_metadata.idx_mapping,
|
||||
prompt_bin_mask=sampling_metadata.prompt_bin_mask,
|
||||
output_bin_counts=sampling_metadata.output_bin_counts,
|
||||
)
|
||||
|
||||
@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton
|
||||
def _min_p_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
min_p_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
min_p = tl.load(min_p_ptr + req_idx).to(tl.float32)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
|
||||
if min_p == 0.0:
|
||||
return
|
||||
|
||||
@@ -39,12 +41,17 @@ def _min_p_kernel(
|
||||
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None:
|
||||
def apply_min_p(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
_min_p_kernel[(num_reqs,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
min_p,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
|
||||
@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
def _penalties_and_temperature_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
repetition_penalty_ptr,
|
||||
frequency_penalty_ptr,
|
||||
presence_penalty_ptr,
|
||||
temperature_ptr,
|
||||
idx_mapping_ptr,
|
||||
prompt_bin_mask_ptr,
|
||||
prompt_bin_mask_stride,
|
||||
output_bin_counts_ptr,
|
||||
@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
|
||||
temperature = tl.load(temperature_ptr + batch_idx)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
|
||||
temperature = tl.load(temperature_ptr + req_state_idx)
|
||||
temperature = tl.where(temperature == 0.0, 1.0, temperature)
|
||||
|
||||
use_rep_penalty = rep_penalty != 1.0
|
||||
@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel(
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
if use_penalty:
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
output_bin_counts = tl.load(
|
||||
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
|
||||
mask=mask,
|
||||
@@ -92,11 +92,11 @@ def apply_penalties_and_temperature(
|
||||
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampling_metadata.idx_mapping,
|
||||
sampling_metadata.repetition_penalty,
|
||||
sampling_metadata.frequency_penalty,
|
||||
sampling_metadata.presence_penalty,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.idx_mapping,
|
||||
sampling_metadata.prompt_bin_mask,
|
||||
sampling_metadata.prompt_bin_mask.stride(0),
|
||||
sampling_metadata.output_bin_counts,
|
||||
|
||||
@@ -71,7 +71,7 @@ class Sampler:
|
||||
apply_penalties_and_temperature(logits, sampling_metadata)
|
||||
# Apply min_p in place.
|
||||
if sampling_metadata.min_p is not None:
|
||||
apply_min_p(logits, sampling_metadata.min_p)
|
||||
apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p)
|
||||
# Apply top_k and/or top_p. This might return a new tensor.
|
||||
logits = apply_top_k_top_p(
|
||||
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
||||
@@ -79,6 +79,7 @@ class Sampler:
|
||||
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
sampling_metadata.idx_mapping,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
@@ -46,7 +44,6 @@ class EagleSpeculator:
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
|
||||
self.vocab_size = self.draft_model_config.get_vocab_size()
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.input_buffers = InputBuffers(
|
||||
@@ -56,7 +53,6 @@ class EagleSpeculator:
|
||||
vocab_size=self.vocab_size,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
self.max_num_tokens,
|
||||
@@ -64,6 +60,11 @@ class EagleSpeculator:
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.idx_mapping = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.temperature = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.float32,
|
||||
@@ -140,7 +141,7 @@ class EagleSpeculator:
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
) -> None:
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
for step in range(1, self.num_speculative_steps):
|
||||
# Run the eagle model.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
@@ -152,8 +153,9 @@ class EagleSpeculator:
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits,
|
||||
self.temperature[:num_reqs],
|
||||
self.seeds[:num_reqs],
|
||||
self.idx_mapping[:num_reqs],
|
||||
self.temperature,
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
)
|
||||
@@ -237,23 +239,27 @@ class EagleSpeculator:
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
|
||||
# NOTE(woosuk): For draft sampling, we only consider the temperature
|
||||
# and ignore the other sampling parameters such as top_k and top_p,
|
||||
# for simplicity and performance.
|
||||
# While this may slightly degrade the acceptance rate, it does not
|
||||
# affect the output distribution after rejection sampling.
|
||||
temperature = self.temperature[:num_reqs]
|
||||
seeds = self.seeds[:num_reqs]
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
idx_mapping = self.idx_mapping[:num_reqs]
|
||||
idx_mapping.copy_(input_batch.idx_mapping)
|
||||
self.temperature.copy_(sampling_metadata.temperature)
|
||||
self.seeds.copy_(sampling_metadata.seeds)
|
||||
# Gather the values and copy them to the pre-allocated buffers.
|
||||
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
|
||||
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits, temperature, seeds, pos + 1, apply_temperature=True
|
||||
logits,
|
||||
idx_mapping,
|
||||
self.temperature,
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
)
|
||||
if self.num_speculative_steps == 1:
|
||||
# Early exit.
|
||||
@@ -273,11 +279,8 @@ class EagleSpeculator:
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, pos
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos)

cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None:
@@ -286,8 +289,9 @@ class EagleSpeculator:
return self.draft_tokens[:num_reqs]

# Run eager mode.
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]

# FIXME(woosuk): This is UNSAFE!!
@@ -295,7 +299,7 @@ class EagleSpeculator:
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
@@ -484,7 +488,7 @@ def prepare_eagle_decode(
input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc.gpu,
input_buffers.query_start_loc,
input_buffers.seq_lens,
hidden_size,
max_model_len,

@@ -8,10 +8,8 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount

@@ -29,7 +27,6 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -37,7 +34,6 @@ class RequestState:
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory

self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
@@ -47,16 +43,18 @@ class RequestState:
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32
# To save GPU memory, we use UVA instead of GPU for this tensor.
self.prefill_token_ids = StagedWriteTensor(
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)

# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros(
self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)

@@ -84,14 +82,16 @@ class RequestState:
self.lora_ids.fill(NO_LORA_ID)

# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.min_p = self._make_param(self.max_num_reqs, torch.float32)
self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.repetition_penalty = UvaBackedTensor(
self.max_num_reqs, dtype=torch.float32
)
self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64)

self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
@@ -111,13 +111,7 @@ class RequestState:
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)

def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)

def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
self._penalties_reqs: list[int] = []

@property
def num_reqs(self) -> int:
@@ -144,12 +138,9 @@ class RequestState:
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids

self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
# Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)

if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
@@ -169,13 +160,7 @@ class RequestState:
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty

if use_penalty(sampling_params):
bincount(
self.prefill_token_ids.gpu[req_idx],
prefill_len,
prompt_len,
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.append(req_idx)

if sampling_params.seed is not None:
seed = sampling_params.seed
@@ -193,6 +178,22 @@ class RequestState:
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs

def apply_staged_writes(self) -> None:
self.prefill_len.copy_to_uva()
self.prefill_token_ids.apply_write()
self.num_computed_tokens.apply_write()

# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
self.prefill_token_ids.gpu[req_idx],
int(self.prefill_len.np[req_idx]),
int(self.prompt_len[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.clear()

def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
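The StagedWriteTensor and UvaBackedTensor helpers used above come from vllm/v1/worker/gpu/buffer_utils.py, which is not part of this excerpt. The pattern they enable is: record per-request writes cheaply on the CPU while requests are added, then flush everything to the device in one place (apply_staged_writes) once per step, instead of issuing a GPU copy per request. A minimal sketch of that pattern under those assumptions (illustrative only, not the real class):

import torch

class StagedWriteTensorSketch:
    def __init__(self, shape, dtype, device):
        self.gpu = torch.zeros(shape, dtype=dtype, device=device)
        self._staged = []  # (row, start, values) tuples recorded on the CPU

    def stage_write(self, row, start, values):
        # Cheap: no GPU work happens here.
        values = torch.as_tensor(values, dtype=self.gpu.dtype)
        self._staged.append((row, start, values))

    def apply_write(self):
        # Flush all staged writes at once; the real implementation batches
        # these into a single copy/kernel rather than a Python loop.
        for row, start, values in self._staged:
            self.gpu[row, start : start + len(values)] = values.to(self.gpu.device)
        self._staged.clear()

# Usage mirrors add_request + apply_staged_writes above:
# buf = StagedWriteTensorSketch((max_num_reqs, max_model_len), torch.int32, "cuda")
# buf.stage_write(req_idx, 0, prefill_token_ids)   # per request, CPU only
# buf.apply_write()                                 # once per step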
@@ -208,30 +209,25 @@ class RequestState:
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping_np]
temperature = self.temperature.copy_np_to_gpu(temperature)
temperature = self.temperature.copy_to_uva()

top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_p = self.top_p.copy_to_uva()[idx_mapping] if not no_top_p else None

top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
top_k = self.top_k.copy_to_uva()[idx_mapping] if not no_top_k else None

min_p = self.min_p.np[idx_mapping_np]
no_min_p = np.all(min_p == 0.0)
min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None
min_p = self.min_p.copy_to_uva() if not no_min_p else None

rep_penalty = self.repetition_penalty.np[idx_mapping_np]
rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty)
freq_penalty = self.frequency_penalty.np[idx_mapping_np]
freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
pres_penalty = self.presence_penalty.np[idx_mapping_np]
pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
rep_penalty = self.repetition_penalty.copy_to_uva()
freq_penalty = self.frequency_penalty.copy_to_uva()
pres_penalty = self.presence_penalty.copy_to_uva()

seeds = self.seeds.np[idx_mapping_np]
seeds = self.seeds.copy_np_to_gpu(seeds)
seeds = self.seeds.copy_to_uva()

num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs))
@@ -239,6 +235,7 @@ class RequestState:
max_num_logprobs = None

return SamplingMetadata(
idx_mapping=idx_mapping,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -246,12 +243,11 @@ class RequestState:
repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty,
presence_penalty=pres_penalty,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
)

def make_lora_inputs(
@@ -272,42 +268,12 @@ class RequestState:
return prompt_lora_mapping, token_lora_mapping, active_lora_requests


class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)

def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)


@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)


class UvaBuffer:
def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
assert is_uva_available()
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)


def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0

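The removed UvaBuffer above (now consolidated into buffer_utils) captures the UVA pattern this PR leans on: a pinned CPU allocation is exposed to CUDA kernels through unified virtual addressing, so CPU-side writes become visible on the GPU without an explicit host-to-device copy, at the cost of the GPU reading over the interconnect at access time. A minimal sketch of how such a buffer is used, reusing the helpers named in the imports above (their exact behavior is assumed from this usage, not verified here):

import torch
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor

assert is_uva_available()
cpu = torch.zeros(4, dtype=torch.int32, device="cpu", pin_memory=True)
np_view = cpu.numpy()                           # zero-copy view for cheap CPU writes
gpu_view = get_cuda_view_from_cpu_tensor(cpu)   # CUDA-addressable view of the same memory

np_view[0] = 7
# A kernel reading gpu_view[0] now observes 7 without any cudaMemcpy; this is
# why UVA suits huge, rarely-read tensors such as prefill_token_ids, where a
# dense GPU allocation would waste memory.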
@@ -4,38 +4,65 @@ import numpy as np
import torch

from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.gpu.input_batch import InputBatch


def apply_grammar_bitmask(
logits: torch.Tensor,
req_ids: list[str],
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
input_buffers: InputBuffers,
) -> None:
input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
class StructuredOutputsWorker:
def __init__(
self,
max_num_logits: int,
vocab_size: int,
):
# NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
# to save an unnecessary CPU-to-CPU copy.
self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
self.grammar_bitmask = UvaBufferPool(
(max_num_logits, cdiv(vocab_size, 32)), torch.int32
)

batch_size = logits.shape[0]
grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
# logits -> bitmask mapping
mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
input_buffers.bitmask_indices.np[:batch_size] = mapping
input_buffers.bitmask_indices.copy_to_gpu(batch_size)
def apply_grammar_bitmask(
self,
logits: torch.Tensor,
input_batch: InputBatch,
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
) -> None:
if not grammar_req_ids:
return

vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
input_buffers.grammar_bitmask.gpu,
input_buffers.grammar_bitmask.gpu.stride(0),
input_buffers.bitmask_indices.gpu,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Construct bitmask -> logits mapping
mapping: list[int] = []
req_ids = input_batch.req_ids
cu_num_logits = input_batch.cu_num_logits_np.tolist()
req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
for grammar_req_id in grammar_req_ids:
req_idx = req_id_to_idx[grammar_req_id]
logits_start_idx = cu_num_logits[req_idx]
logits_end_idx = cu_num_logits[req_idx + 1]
mapping.extend(range(logits_start_idx, logits_end_idx))
# Copy the mapping.
mapping_np = np.array(mapping, dtype=np.int32)
logits_indices = self.logits_indices.copy_to_uva(mapping_np)

# Copy the bitmask.
bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)

num_masks = bitmask.shape[0]
assert num_masks == len(mapping)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
logits_indices,
bitmask,
bitmask.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)


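A small worked example of the mapping construction above (the request IDs and counts are hypothetical):

req_ids = ["a", "b", "c"]        # batch order in input_batch
cu_num_logits = [0, 1, 3, 4]     # request "b" owns logits rows 1..2 (spec decoding)
grammar_req_ids = ["b", "c"]     # structured-output requests, in bitmask row order

# The loop yields "b" -> rows 1, 2 and "c" -> row 3, so mapping == [1, 2, 3]:
# bitmask row m constrains logits row mapping[m], and the assert requires
# grammar_bitmask to have exactly len(mapping) == 3 rows.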
# Adapted from
@@ -44,17 +71,14 @@ def apply_grammar_bitmask(
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
logits_indices_ptr,
bitmask_ptr,
bitmask_stride,
bitmask_indices_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
logits_idx = tl.program_id(0)
bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
if bitmask_idx == -1:
# No bitmask to apply.
return
bitmask_idx = tl.program_id(0)
logits_idx = tl.load(logits_indices_ptr + bitmask_idx)

# Load the bitmask.
block_id = tl.program_id(1)

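The kernel body is truncated in this excerpt, so the actual masking logic is not shown. Assuming the usual packed-bitmask layout implied by the cdiv(vocab_size, 32) int32 words above (bit j of word w allows token w * 32 + j; this layout is an assumption, not confirmed by the diff), a CPU reference of what one launch computes might look like:

import numpy as np

def apply_bitmask_reference(logits, logits_indices, bitmask):
    # logits: [num_logits, vocab_size] float32
    # bitmask: [num_masks, ceil(vocab_size / 32)] int32; row m applies to
    # logits row logits_indices[m].
    vocab_size = logits.shape[-1]
    for m in range(bitmask.shape[0]):
        row = logits_indices[m]
        words = bitmask[m].astype(np.uint32)
        # Unpack the 32-bit words into a per-token "allowed" boolean vector.
        bits = (words[:, None] >> np.arange(32, dtype=np.uint32)) & 1
        allowed = bits.reshape(-1)[:vocab_size].astype(bool)
        logits[row, ~allowed] = -np.inf   # disallowed tokens can never be sampled
    return logits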