diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
index 9261ff4da..d45917d4b 100644
--- a/vllm/v1/worker/gpu/block_table.py
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -6,9 +6,8 @@
 import torch
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
-from vllm.utils.platform_utils import is_uva_available
-from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
+from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
 
 
 class BlockTables:
@@ -26,19 +25,16 @@
         self.max_model_len = max_model_len
         self.device = device
 
-        if not is_uva_available():
-            raise RuntimeError("UVA is not available")
-
         self.num_kv_cache_groups = len(self.block_sizes)
         # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
-        self.block_tables: list[UvaBuffer] = []
+        self.block_tables: list[StagedWriteTensor] = []
         for i in range(self.num_kv_cache_groups):
            block_size = self.block_sizes[i]
            max_num_blocks = cdiv(self.max_model_len, block_size)
-            block_table = UvaBuffer(
-                self.max_num_reqs,
-                max_num_blocks,
+            block_table = StagedWriteTensor(
+                (self.max_num_reqs, max_num_blocks),
                 dtype=torch.int32,
+                device=device,
             )
             self.block_tables.append(block_table)
         self.block_table_ptrs = self._make_ptr_tensor(
@@ -53,9 +49,8 @@
         self.block_sizes_tensor = torch.tensor(
             self.block_sizes, dtype=torch.int32, device=self.device
         )
-        self.num_blocks = UvaBuffer(
-            self.num_kv_cache_groups,
-            self.max_num_reqs,
+        self.num_blocks = UvaBackedTensor(
+            (self.num_kv_cache_groups, self.max_num_reqs),
             dtype=torch.int32,
         )
 
@@ -75,13 +70,11 @@
 
     def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
         # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
-        ptrs_tensor_cpu = torch.tensor(
+        return torch.tensor(
             [t.data_ptr() for t in x],
             dtype=torch.uint64,
-            device="cpu",
-            pin_memory=True,
+            device=self.device,
         )
-        return ptrs_tensor_cpu.to(self.device, non_blocking=True)
 
     def append_block_ids(
         self,
@@ -90,19 +83,17 @@
         overwrite: bool,
     ) -> None:
         for i in range(self.num_kv_cache_groups):
-            block_ids = new_block_ids[i]
-            num_new_blocks = len(block_ids)
-            if num_new_blocks == 0:
-                continue
-
-            # TODO(woosuk): Too many Numpy invocations. Optimize this.
             start = self.num_blocks.np[i, req_index] if not overwrite else 0
-            end = start + num_new_blocks
-            if num_new_blocks == 1:
-                self.block_tables[i].np[req_index, start] = block_ids[0]
-            else:
-                self.block_tables[i].np[req_index, start:end] = block_ids
-            self.num_blocks.np[i, req_index] = end
+            block_ids = new_block_ids[i]
+            self.block_tables[i].stage_write(req_index, start, block_ids)
+            self.num_blocks.np[i, req_index] = start + len(block_ids)
+
+    def apply_staged_writes(self) -> None:
+        # TODO(woosuk): This can be inefficient since it launches one kernel per
+        # block table. Implement a kernel to handle all block tables at once.
+        for block_table in self.block_tables:
+            block_table.apply_write()
+        self.num_blocks.copy_to_uva()
 
     def gather_block_tables(
         self,
@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
     ptr = tl.load(ptr_to_ptr)
     ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
     return tl.multiple_of(ptr, 16)
-
-
-class UvaBuffer:
-    def __init__(self, *size, dtype: torch.dtype):
-        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
-        self.np = self.cpu.numpy()
-        self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
diff --git a/vllm/v1/worker/gpu/buffer_utils.py b/vllm/v1/worker/gpu/buffer_utils.py
new file mode 100644
index 000000000..0662f0278
--- /dev/null
+++ b/vllm/v1/worker/gpu/buffer_utils.py
@@ -0,0 +1,218 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.utils.math_utils import next_power_of_2
+from vllm.utils.platform_utils import is_uva_available
+from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
+
+
+class UvaBuffer:
+    def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
+        if not is_uva_available():
+            raise RuntimeError("UVA is not available")
+        self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
+        self.np = self.cpu.numpy()
+        self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
+
+
+class UvaBufferPool:
+    def __init__(
+        self,
+        size: int | Sequence[int],
+        dtype: torch.dtype,
+        max_concurrency: int = 2,
+    ):
+        self.size = size
+        self.dtype = dtype
+        self.max_concurrency = max_concurrency
+
+        # UVA buffers for concurrency
+        self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
+        # Current buffer index
+        self._curr = 0
+
+    def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
+        # Round robin to the next buffer.
+        self._curr = (self._curr + 1) % self.max_concurrency
+        buf = self._uva_bufs[self._curr]
+        # CPU-to-CPU copy
+        dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
+        n = len(x)
+        dst[:n] = x
+        return buf.uva[:n]
+
+    def copy_to_gpu(
+        self,
+        x: torch.Tensor | np.ndarray,
+        out: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        uva = self.copy_to_uva(x)
+        if out is None:
+            # CPU-to-GPU copy
+            return uva.clone()
+        # CPU-to-GPU copy
+        return out.copy_(uva, non_blocking=True)
+
+
+class UvaBackedTensor:
+    def __init__(
+        self,
+        size: int | Sequence[int],
+        dtype: torch.dtype,
+        max_concurrency: int = 2,
+    ):
+        self.dtype = dtype
+        self.max_concurrency = max_concurrency
+
+        # Source of truth
+        self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
+        self.np = self.cpu.numpy()
+
+        # Buffers for concurrency
+        self.pool = UvaBufferPool(size, dtype, max_concurrency)
+        self.gpu = self.pool.copy_to_uva(self.np)
+
+    def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
+        # CPU-to-CPU copy
+        self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
+        return self.gpu
+
+
+class StagedWriteTensor:
+    def __init__(
+        self,
+        size: int | Sequence[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        max_concurrency: int = 2,
+        uva_instead_of_gpu: bool = False,
+    ):
+        if dtype not in [torch.int32, torch.int64]:
+            raise ValueError(
+                f"Unsupported dtype {dtype}: should be either int32 or int64"
+            )
+        self.num_rows = size if isinstance(size, int) else size[0]
+        self.dtype = dtype
+        self.max_concurrency = max_concurrency
+
+        if not uva_instead_of_gpu:
+            # Create a GPU tensor (default)
+            self.gpu = torch.zeros(size, dtype=dtype, device=device)
+        else:
+            # For a large but not-frequently-accessed tensor, we can use UVA instead of
+            # GPU to save GPU memory
+            self._uva_buf = UvaBuffer(size, dtype)
+            self.gpu = self._uva_buf.uva
+
+        self._staged_write_indices: list[int] = []
+        self._staged_write_starts: list[int] = []
+        self._staged_write_contents: list[int] = []
+        self._staged_write_cu_lens: list[int] = []
+
+        self.write_indices = UvaBufferPool(
+            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
+        )
+        self.write_starts = UvaBufferPool(
+            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
+        )
+        init_size = next_power_of_2(self.num_rows)
+        self.write_contents = UvaBufferPool(
+            init_size, dtype=dtype, max_concurrency=max_concurrency
+        )
+        self.write_cu_lens = UvaBufferPool(
+            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
+        )
+
+    def stage_write(self, index: int, start: int, x: list[int]) -> None:
+        assert index >= 0
+        assert start >= 0
+        if not x:
+            return
+        self._staged_write_indices.append(index)
+        self._staged_write_starts.append(start)
+        self._staged_write_contents.extend(x)
+        self._staged_write_cu_lens.append(len(self._staged_write_contents))
+
+    def stage_write_elem(self, index: int, x: int) -> None:
+        assert index >= 0
+        self._staged_write_indices.append(index)
+        self._staged_write_starts.append(0)
+        self._staged_write_contents.append(x)
+        self._staged_write_cu_lens.append(len(self._staged_write_contents))
+
+    def apply_write(self) -> None:
+        n = len(self._staged_write_indices)
+        if n == 0:
+            return
+
+        indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
+        starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
+        cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
+
+        # Special handling for write_contents
+        diff_len = len(self._staged_write_contents)
+        assert isinstance(self.write_contents.size, int)
+        if diff_len > self.write_contents.size:
+            # Re-allocate a larger buffer for the write_contents
+            new_size = next_power_of_2(diff_len)
+            self.write_contents = UvaBufferPool(
+                new_size, dtype=self.dtype, max_concurrency=self.max_concurrency
+            )
+            # NOTE(woosuk): Since the previous write_contents buffer is released,
+            # we perform a synchronization here to ensure that all data transfers
+            # involving the old buffer have finished before allocating a new one.
+            # This prevents potential race conditions. The slight overhead is
+            # negligible because the reallocations are infrequent in practice.
+            torch.cuda.synchronize()
+        contents_uva = self.write_contents.copy_to_uva(self._staged_write_contents)
+
+        # Write diffs to the GPU buffer
+        _apply_write_kernel[(n,)](
+            self.gpu,
+            self.gpu.stride(0),
+            indices_uva,
+            starts_uva,
+            contents_uva,
+            cu_lens_uva,
+            BLOCK_SIZE=1024,
+        )
+        # Clear the staged writes
+        self.clear_staged_writes()
+
+    def clear_staged_writes(self) -> None:
+        self._staged_write_indices.clear()
+        self._staged_write_starts.clear()
+        self._staged_write_contents.clear()
+        self._staged_write_cu_lens.clear()
+
+
+@triton.jit
+def _apply_write_kernel(
+    output_ptr,
+    output_stride,
+    write_indices_ptr,
+    write_starts_ptr,
+    write_contents_ptr,
+    write_cu_lens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_idx = tl.load(write_indices_ptr + pid)
+    start_idx = tl.load(write_starts_ptr + pid)
+
+    cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
+    cu_end = tl.load(write_cu_lens_ptr + pid)
+    content_len = cu_end - cu_start
+
+    for i in range(0, content_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < content_len
+        content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
+        tl.store(
+            output_ptr + row_idx * output_stride + start_idx + block, content, mask=mask
+        )
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index a84d8e2b9..a7c20ec8b 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
     kv_cache_config: KVCacheConfig,
 ) -> dict[str, Any]:
     num_tokens_per_req = num_tokens // num_reqs
-    query_start_loc = input_buffers.query_start_loc
-    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
-    query_start_loc.np[num_reqs:] = num_tokens
-    query_start_loc.copy_to_gpu()
+
+    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
+    query_start_loc_np[-1] = num_tokens
+    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
+    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
+    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
+    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
 
     # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
     # rather than max_model_len.
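# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the patch: how the staged-
# write buffers introduced in buffer_utils.py above are intended to be used,
# following the BlockTables changes. The shapes, block IDs, and the "cuda"
# device below are hypothetical; a CUDA device with UVA support is assumed.
import torch

from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

# A 2-D GPU-resident table whose rows are updated through staged writes.
block_table = StagedWriteTensor(
    (8, 16), dtype=torch.int32, device=torch.device("cuda")
)
# A CPU "source of truth" array mirrored into a small pool of pinned UVA buffers.
num_blocks = UvaBackedTensor((1, 8), dtype=torch.int32)

# CPU side: record new block IDs without launching any GPU work yet.
block_table.stage_write(index=0, start=0, x=[3, 7, 9])
num_blocks.np[0, 0] = 3

# Flush once per step: a single Triton kernel applies all staged rows, and the
# NumPy data is copied into the next UVA buffer of the round-robin pool.
block_table.apply_write()
num_blocks.copy_to_uva()
# ---------------------------------------------------------------------------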
@@ -245,8 +248,8 @@ def prepare_inputs_to_capture( attn_metadata_builders=attn_metadata_builders, num_reqs=num_reqs, num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1], - query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1], + query_start_loc_gpu=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, seq_lens=input_buffers.seq_lens, max_seq_len=max_model_len, block_tables=input_block_tables, diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index f158eef09..78889d2ad 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -8,8 +8,6 @@ import torch from vllm.triton_utils import tl, triton from vllm.utils import random_uuid -from vllm.utils.math_utils import cdiv -from vllm.v1.utils import CpuGpuBuffer class InputBuffers: @@ -21,30 +19,17 @@ class InputBuffers: vocab_size: int, dtype: torch.dtype, device: torch.device, - pin_memory: bool, ): self.max_num_reqs = max_num_reqs self.max_num_tokens = max_num_tokens self.device = device - self.pin_memory = pin_memory - self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device) self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device) - self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) + self.query_start_loc = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) - self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) - - # Structured outputs. - self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32) - self.grammar_bitmask = self._make_buffer( - max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32 - ) - - def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer( - *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device - ) @dataclass @@ -56,6 +41,8 @@ class InputBatch: # batch_idx -> req_state_idx idx_mapping: torch.Tensor idx_mapping_np: np.ndarray + # Identical to idx_mapping except for spec decoding. + expanded_idx_mapping: torch.Tensor # [num_reqs] # batch_idx -> num_scheduled_tokens @@ -83,6 +70,7 @@ class InputBatch: logits_indices: torch.Tensor # [num_reqs + 1] cu_num_logits: torch.Tensor + cu_num_logits_np: np.ndarray @classmethod def make_dummy( @@ -96,33 +84,41 @@ class InputBatch: req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) + expanded_idx_mapping = idx_mapping num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) num_scheduled_tokens[-1] += num_tokens % num_reqs assert int(num_scheduled_tokens.sum()) == num_tokens - input_buffers.query_start_loc.np[0] = 0 - input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum( - num_scheduled_tokens - ) - input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens - query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1] - query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1] # seq_len equals to query_len input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs + # Pad for full CUDA graph mode. 
input_buffers.seq_lens[num_reqs:] = 0 seq_lens = input_buffers.seq_lens[:num_reqs] + query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32) + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) + input_buffers.query_start_loc[0] = 0 + torch.cumsum( + seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1] + ) + # Pad for full CUDA graph mode. + input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens + query_start_loc = input_buffers.query_start_loc[: num_reqs + 1] + input_ids = input_buffers.input_ids[:num_tokens] positions = input_buffers.positions[:num_tokens] # attn_metadata = defaultdict(lambda: None) logits_indices = query_start_loc[1:] - 1 cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32) + cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32) return cls( req_ids=req_ids, num_reqs=num_reqs, idx_mapping=idx_mapping, idx_mapping_np=idx_mapping_np, + expanded_idx_mapping=expanded_idx_mapping, num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, num_tokens_after_padding=num_tokens, @@ -135,6 +131,7 @@ class InputBatch: attn_metadata=None, # type: ignore logits_indices=logits_indices, cu_num_logits=cu_num_logits, + cu_num_logits_np=cu_num_logits_np, ) @@ -473,3 +470,38 @@ def post_update( query_start_loc, num_warps=1, ) + + +@triton.jit +def _expand_idx_mapping_kernel( + idx_mapping_ptr, + expanded_idx_mapping_ptr, + cu_num_logits_ptr, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + start_idx = tl.load(cu_num_logits_ptr + req_idx) + end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + num_tokens = end_idx - start_idx + + block = tl.arange(0, BLOCK_SIZE) + mask = block < num_tokens + req_state_idx = tl.load(idx_mapping_ptr + req_idx) + tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask) + + +def expand_idx_mapping( + idx_mapping: torch.Tensor, + total_num_logits: int, + cu_num_logits: torch.Tensor, + max_expand_len: int, +) -> torch.Tensor: + num_reqs = idx_mapping.shape[0] + expanded_idx_mapping = idx_mapping.new_empty(total_num_logits) + _expand_idx_mapping_kernel[(num_reqs,)]( + idx_mapping, + expanded_idx_mapping, + cu_num_logits, + BLOCK_SIZE=triton.next_power_of_2(max_expand_len), + ) + return expanded_idx_mapping diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 20ec89657..06dc7467f 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -15,7 +15,6 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib -from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig @@ -24,7 +23,7 @@ from vllm.v1.outputs import ( LogprobsTensors, ModelRunnerOutput, ) -from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier +from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.attn_utils import ( build_attn_metadata, get_kv_cache_spec, @@ -32,6 +31,7 @@ from vllm.v1.worker.gpu.attn_utils import ( init_kv_cache, ) from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from 
vllm.v1.worker.gpu.dp_utils import ( get_batch_metadata_across_dp, @@ -41,22 +41,20 @@ from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, combine_sampled_and_draft_tokens, + expand_idx_mapping, get_num_sampled_and_rejected, post_update, prepare_pos_seq_lens, prepare_prefill_inputs, ) from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs -from vllm.v1.worker.gpu.sample.metadata import ( - SamplingMetadata, - expand_sampling_metadata, -) +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.spec_decode import init_speculator from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.states import RequestState -from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask +from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -81,7 +79,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.observability_config = vllm_config.observability_config self.device = device - self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype self.kv_cache_dtype = self.dtype if self.cache_config.cache_dtype != "auto": @@ -123,7 +120,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_speculative_steps=self.num_speculative_steps, vocab_size=self.vocab_size, device=self.device, - pin_memory=self.pin_memory, ) self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, @@ -132,12 +128,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): vocab_size=self.vocab_size, dtype=self.dtype, device=self.device, - pin_memory=self.pin_memory, ) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) # CUDA graphs. self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device) + # Structured outputs worker. + self.structured_outputs_worker = StructuredOutputsWorker( + max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1), + vocab_size=self.vocab_size, + ) + + # Buffers for CPU-to-GPU copies. 
+ self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32) + self.tmp_cu_num_logits = UvaBufferPool(self.max_num_reqs + 1, torch.int32) + self.tmp_query_start_loc = UvaBufferPool(self.max_num_reqs + 1, torch.int32) def update_max_model_len(self, max_model_len: int) -> None: self.max_model_len = max_model_len @@ -228,16 +233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): slot_mappings = self.block_tables.get_dummy_slot_mappings( input_batch.num_tokens ) - query_start_loc = self.input_buffers.query_start_loc - query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1] - query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1] attn_metadata = build_attn_metadata( attn_metadata_builders=self.attn_metadata_builders, num_reqs=input_batch.num_reqs, num_tokens=input_batch.num_tokens, - query_start_loc_gpu=query_start_loc_gpu, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens=self.input_buffers.seq_lens, + query_start_loc_gpu=input_batch.query_start_loc, + query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np), + seq_lens=input_batch.seq_lens, max_seq_len=self.max_model_len, block_tables=block_tables, slot_mappings=slot_mappings, @@ -396,8 +398,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.block_tables.append_block_ids( req_index, new_req_data.block_ids, overwrite=True ) - if scheduler_output.scheduled_new_reqs: - self.req_states.prefill_len.copy_to_gpu() # Add new blocks for the existing requests. cached_reqs = scheduler_output.scheduled_cached_reqs @@ -409,6 +409,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_index, req_new_block_ids, overwrite=False ) + self.req_states.apply_staged_writes() + self.block_tables.apply_staged_writes() + def prepare_inputs( self, scheduler_output: SchedulerOutput, @@ -431,19 +434,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): idx_mapping_list = [ self.req_states.req_id_to_index[req_id] for req_id in req_ids ] - idx_mapping = self.input_buffers.idx_mapping - idx_mapping.np[:num_reqs] = idx_mapping_list - idx_mapping_np = idx_mapping.np[:num_reqs] - idx_mapping = idx_mapping.copy_to_gpu(num_reqs) + idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32) + idx_mapping = self.tmp_idx_mapping.copy_to_gpu(idx_mapping_np) # Get the number of draft tokens for each request. if not scheduler_output.scheduled_spec_decode_tokens: # No draft token scheduled (common case). 
total_num_draft_tokens = 0 total_num_logits = num_reqs + cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32) cu_num_logits = torch.arange( num_reqs + 1, device=self.device, dtype=torch.int32 ) + expanded_idx_mapping = idx_mapping else: draft_tokens = scheduler_output.scheduled_spec_decode_tokens num_draft_tokens = np.array( @@ -456,44 +459,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): total_num_draft_tokens = int(num_draft_tokens.sum()) total_num_logits = num_reqs + total_num_draft_tokens - np.cumsum( - num_draft_tokens + 1, - out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1], + num_logits = num_draft_tokens + 1 + cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32) + cu_num_logits_np[0] = 0 + np.cumsum(num_logits, out=cu_num_logits_np[1:]) + cu_num_logits = self.tmp_cu_num_logits.copy_to_gpu(cu_num_logits_np) + + expanded_idx_mapping = expand_idx_mapping( + idx_mapping, + total_num_logits, + cu_num_logits, + max_expand_len=self.num_speculative_steps + 1, ) - cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1) # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] block_tables = self.block_tables.gather_block_tables(idx_mapping) # Get query_start_loc. - np.cumsum( - num_scheduled_tokens, - out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1], - ) + query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1]) # Pad for full CUDA graph mode. # Some attention backends like FA3 require query_start_loc to be non-decreasing. - self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens - self.input_buffers.query_start_loc.copy_to_gpu() - query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1] - query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1] - query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1] + query_start_loc_np[num_reqs + 1 :] = num_tokens + self.tmp_query_start_loc.copy_to_gpu( + query_start_loc_np, + out=self.input_buffers.query_start_loc, + ) + query_start_loc_np = query_start_loc_np[: num_reqs + 1] + query_start_loc_cpu = torch.from_numpy(query_start_loc_np) + query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] # Get prefill tokens. prepare_prefill_inputs( self.input_buffers.input_ids, self.req_states.next_prefill_tokens, idx_mapping, - query_start_loc_gpu, + query_start_loc, self.req_states.prefill_token_ids.gpu, self.req_states.prefill_len.gpu, - self.req_states.num_computed_tokens, + self.req_states.num_computed_tokens.gpu, ) # Prepare positions and seq_lens. 
prepare_pos_seq_lens( idx_mapping, - query_start_loc_gpu, - self.req_states.num_computed_tokens, + query_start_loc, + self.req_states.num_computed_tokens.gpu, self.input_buffers.positions, self.input_buffers.seq_lens, ) @@ -505,7 +517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_buffers.input_ids, idx_mapping, self.req_states.last_sampled_tokens, - query_start_loc_gpu, + query_start_loc, seq_lens, self.req_states.prefill_len.gpu, self.req_states.draft_tokens, @@ -515,7 +527,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Compute slot mappings: [num_kv_cache_groups, num_tokens] slot_mappings = self.block_tables.compute_slot_mappings( - query_start_loc_gpu, self.input_buffers.positions[:num_tokens] + query_start_loc, self.input_buffers.positions[:num_tokens] ) # Layer name -> attention metadata. @@ -523,7 +535,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata_builders=self.attn_metadata_builders, num_reqs=num_reqs, num_tokens=num_tokens, - query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_gpu=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, seq_lens=self.input_buffers.seq_lens, max_seq_len=self.max_model_len, @@ -539,11 +551,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs=num_reqs, idx_mapping=idx_mapping, idx_mapping_np=idx_mapping_np, + expanded_idx_mapping=expanded_idx_mapping, num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, num_tokens_after_padding=num_tokens_after_padding, num_draft_tokens=total_num_draft_tokens, - query_start_loc=query_start_loc_gpu, + query_start_loc=query_start_loc, query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, input_ids=input_ids, @@ -551,6 +564,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata=attn_metadata, logits_indices=logits_indices, cu_num_logits=cu_num_logits, + cu_num_logits_np=cu_num_logits_np, ) def sample( @@ -564,16 +578,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits = self.model.compute_logits(sample_hidden_states) if grammar_output is not None: # Apply grammar bitmask to the logits in-place. - # TODO(woosuk): Make compatible with spec decoding. - assert input_batch.num_draft_tokens == 0 - with async_barrier(self.structured_outputs_event): - apply_grammar_bitmask( - logits, - input_batch.req_ids, - grammar_output.structured_output_request_ids, - grammar_output.grammar_bitmask, - self.input_buffers, - ) + self.structured_outputs_worker.apply_grammar_bitmask( + logits, + input_batch, + grammar_output.structured_output_request_ids, + grammar_output.grammar_bitmask, + ) # Sample tokens and compute logprobs (if needed). sampler_output = self.sampler(logits, sampling_metadata) @@ -641,8 +651,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Handle chunked prompts. pos_after_step = computed_prefill + input_batch.num_scheduled_tokens is_prompt_chunked = pos_after_step < prompt_lens - prefill_token_ids = self.req_states.prefill_token_ids.np - query_start_loc = self.input_buffers.query_start_loc.np + prefill_token_ids = self.req_states.prefill_token_ids.gpu + query_start_loc_np = input_batch.query_start_loc_np for i, req_id in enumerate(input_batch.req_ids): if not needs_prompt_logprobs[i]: continue @@ -650,10 +660,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue # The prompt is chunked. Get the next prompt token. 
req_idx = input_batch.idx_mapping_np[i] - next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]]) - idx = int(query_start_loc[i + 1] - 1) - # Set the next prompt token. - # NOTE(woosuk): This triggers a GPU operation. + idx = int(query_start_loc_np[i + 1] - 1) + # NOTE(woosuk): This triggers two GPU operations. + next_prompt_token = prefill_token_ids[req_idx, pos_after_step[i]] token_ids[idx] = next_prompt_token # NOTE(woosuk): We mask out logprobs for negative tokens. @@ -669,8 +678,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not needs_prompt_logprobs[i]: continue - start_idx = query_start_loc[i] - end_idx = query_start_loc[i + 1] + start_idx = query_start_loc_np[i] + end_idx = query_start_loc_np[i + 1] assert start_idx < end_idx, ( f"start_idx ({start_idx}) >= end_idx ({end_idx})" ) @@ -714,7 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update the number of computed tokens. post_update( input_batch.idx_mapping, - self.req_states.num_computed_tokens, + self.req_states.num_computed_tokens.gpu, self.req_states.last_sampled_tokens, self.req_states.output_bin_counts, sampled_tokens, @@ -825,61 +834,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert intermediate_tensors is None if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run: # No need to run the model. - with async_barrier(self.input_prep_event): - self.update_states(scheduler_output) - return EMPTY_MODEL_RUNNER_OUTPUT + self.update_states(scheduler_output) + return EMPTY_MODEL_RUNNER_OUTPUT - # NOTE: Call this before the async barrier so CPU all-reduce and - # GPU execution can overlap. cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = ( self.get_cudagraph_and_dp_padding(scheduler_output) ) - with async_barrier(self.input_prep_event): - self.update_states(scheduler_output) - if num_tokens_after_padding == 0: - # All DP ranks have zero tokens to run. - return EMPTY_MODEL_RUNNER_OUTPUT + self.update_states(scheduler_output) + if num_tokens_after_padding == 0: + # All DP ranks have zero tokens to run. + return EMPTY_MODEL_RUNNER_OUTPUT - if not dummy_run: - # Common case. - # Prepare all the inputs and copy to the input buffers. - input_batch = self.prepare_inputs( - scheduler_output, - num_tokens_after_padding, - ) + if not dummy_run: + # Common case. + # Prepare all the inputs and copy to the input buffers. + input_batch = self.prepare_inputs( + scheduler_output, + num_tokens_after_padding, + ) - # NOTE(woosuk): Sampling metadata should be built under the async - # barrier to avoid race conditions. - pos = input_batch.positions[input_batch.logits_indices] - sampling_metadata = self.req_states.make_sampling_metadata( - input_batch.idx_mapping, input_batch.idx_mapping_np, pos - ) - if input_batch.num_draft_tokens > 0: - sampling_metadata = expand_sampling_metadata( - sampling_metadata, - input_batch.cu_num_logits, - max_expand_len=self.num_speculative_steps + 1, - ) + pos = input_batch.positions[input_batch.logits_indices] + sampling_metadata = self.req_states.make_sampling_metadata( + input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos + ) - if self.lora_config: - # Activate LoRA adapters. - lora_inputs = self.req_states.make_lora_inputs( - input_batch.req_ids, - input_batch.idx_mapping_np, - input_batch.num_scheduled_tokens, - ) - self._set_active_loras(*lora_inputs) - else: - # No actual tokens to run. A dummy run for DP. 
- num_reqs = min(num_tokens_after_padding, self.max_num_reqs) - input_batch = InputBatch.make_dummy( - num_reqs=num_reqs, - num_tokens=num_tokens_after_padding, - input_buffers=self.input_buffers, - device=self.device, + if self.lora_config: + # Activate LoRA adapters. + lora_inputs = self.req_states.make_lora_inputs( + input_batch.req_ids, + input_batch.idx_mapping_np, + input_batch.num_scheduled_tokens, ) - self.prepare_dummy_attn_metadata(input_batch) - sampling_metadata = None + self._set_active_loras(*lora_inputs) + else: + # No actual tokens to run. A dummy run for DP. + num_reqs = min(num_tokens_after_padding, self.max_num_reqs) + input_batch = InputBatch.make_dummy( + num_reqs=num_reqs, + num_tokens=num_tokens_after_padding, + input_buffers=self.input_buffers, + device=self.device, + ) + self.prepare_dummy_attn_metadata(input_batch) + sampling_metadata = None # Run model. if cudagraph_mode == CUDAGraphMode.FULL: diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py index a95bf1e7a..0cb80f833 100644 --- a/vllm/v1/worker/gpu/sample/gumbel.py +++ b/vllm/v1/worker/gpu/sample/gumbel.py @@ -13,6 +13,7 @@ def _gumbel_sample_kernel( local_max_stride, logits_ptr, logits_stride, + idx_mapping_ptr, seeds_ptr, pos_ptr, temp_ptr, @@ -20,22 +21,24 @@ def _gumbel_sample_kernel( BLOCK_SIZE: tl.constexpr, APPLY_TEMPERATURE: tl.constexpr, ): - req_idx = tl.program_id(0) + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + block_idx = tl.program_id(1) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = block < vocab_size logits = tl.load( - logits_ptr + req_idx * logits_stride + block, + logits_ptr + batch_idx * logits_stride + block, mask=mask, other=float("-inf"), ) logits = logits.to(tl.float32) - temp = tl.load(temp_ptr + req_idx).to(tl.float32) + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) if temp != 0.0: # Calculate the seed for gumbel noise. - seed = tl.load(seeds_ptr + req_idx) - pos = tl.load(pos_ptr + req_idx) + seed = tl.load(seeds_ptr + req_state_idx) + pos = tl.load(pos_ptr + batch_idx) gumbel_seed = tl.randint(seed, pos) # Generate gumbel noise. 
@@ -55,12 +58,13 @@ def _gumbel_sample_kernel( idx = tl.argmax(logits, axis=0) token_id = block_idx * BLOCK_SIZE + idx value = tl.max(logits, axis=0) - tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id) - tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value) + tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id) + tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value) def gumbel_sample( logits: torch.Tensor, # [num_reqs, vocab_size] + idx_mapping: torch.Tensor, # [num_reqs] temperature: torch.Tensor, # [num_reqs] seed: torch.Tensor, # [num_reqs] pos: torch.Tensor, # [num_reqs] @@ -88,6 +92,7 @@ def gumbel_sample( local_max.stride(0), logits, logits.stride(0), + idx_mapping, seed, pos, temperature, diff --git a/vllm/v1/worker/gpu/sample/metadata.py b/vllm/v1/worker/gpu/sample/metadata.py index f10c72049..27167fd20 100644 --- a/vllm/v1/worker/gpu/sample/metadata.py +++ b/vllm/v1/worker/gpu/sample/metadata.py @@ -4,20 +4,23 @@ from dataclasses import dataclass import torch -from vllm.triton_utils import tl, triton - @dataclass class SamplingMetadata: + idx_mapping: torch.Tensor + temperature: torch.Tensor top_p: torch.Tensor | None top_k: torch.Tensor | None min_p: torch.Tensor | None + # For penalties repetition_penalty: torch.Tensor frequency_penalty: torch.Tensor presence_penalty: torch.Tensor + prompt_bin_mask: torch.Tensor + output_bin_counts: torch.Tensor seeds: torch.Tensor pos: torch.Tensor @@ -25,11 +28,6 @@ class SamplingMetadata: # None means no logprobs, 0 means sampled token logprobs only max_num_logprobs: int | None - # For penalties - idx_mapping: torch.Tensor - prompt_bin_mask: torch.Tensor - output_bin_counts: torch.Tensor - @classmethod def make_dummy( cls, @@ -37,6 +35,8 @@ class SamplingMetadata: device: torch.device, ) -> "SamplingMetadata": assert num_reqs > 0 + idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) + temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) temperature[0] = 0.5 # TODO(woosuk): Use top-p and top-k for dummy sampler. @@ -51,18 +51,19 @@ class SamplingMetadata: repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) - seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) - pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) - max_num_logprobs = 20 - idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) # NOTE(woosuk): These are placeholder tensors to avoid None checks in the # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton # specialization and re-compilation at runtime. 
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) + seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) + pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) + max_num_logprobs = 20 + return cls( + idx_mapping=idx_mapping, temperature=temperature, top_p=top_p, top_k=top_k, @@ -70,123 +71,9 @@ class SamplingMetadata: repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + prompt_bin_mask=prompt_bin_mask, + output_bin_counts=output_bin_counts, seeds=seeds, pos=pos, max_num_logprobs=max_num_logprobs, - idx_mapping=idx_mapping, - prompt_bin_mask=prompt_bin_mask, - output_bin_counts=output_bin_counts, ) - - -# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None. -@triton.jit -def _expand_sampling_metadata_kernel( - temp_ptr, - expanded_temp_ptr, - top_p_ptr, - expanded_top_p_ptr, - top_k_ptr, - expanded_top_k_ptr, - min_p_ptr, - expanded_min_p_ptr, - rep_penalty_ptr, - expanded_rep_penalty_ptr, - freq_penalty_ptr, - expanded_freq_penalty_ptr, - pres_penalty_ptr, - expanded_pres_penalty_ptr, - seeds_ptr, - expanded_seeds_ptr, - cu_num_logits_ptr, - BLOCK_SIZE: tl.constexpr, -): - req_idx = tl.program_id(0) - start_idx = tl.load(cu_num_logits_ptr + req_idx) - end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) - num_tokens = end_idx - start_idx - - block = tl.arange(0, BLOCK_SIZE) - mask = block < num_tokens - - temp = tl.load(temp_ptr + req_idx) - tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask) - - if top_p_ptr is not None: - top_p = tl.load(top_p_ptr + req_idx) - tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask) - - if top_k_ptr is not None: - top_k = tl.load(top_k_ptr + req_idx) - tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) - - if min_p_ptr is not None: - min_p = tl.load(min_p_ptr + req_idx) - tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask) - - rep_penalty = tl.load(rep_penalty_ptr + req_idx) - tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask) - - freq_penalty = tl.load(freq_penalty_ptr + req_idx) - tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask) - - pres_penalty = tl.load(pres_penalty_ptr + req_idx) - tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask) - - seed = tl.load(seeds_ptr + req_idx) - tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask) - - -def expand_sampling_metadata( - sampling_metadata: SamplingMetadata, - cu_num_logits: torch.Tensor, - max_expand_len: int, -) -> SamplingMetadata: - total_num_logits = sampling_metadata.pos.shape[0] - create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None - expanded_temp = create_empty(sampling_metadata.temperature) - expanded_top_p = create_empty(sampling_metadata.top_p) - expanded_top_k = create_empty(sampling_metadata.top_k) - expanded_min_p = create_empty(sampling_metadata.min_p) - expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty) - expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty) - expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty) - expanded_seeds = create_empty(sampling_metadata.seeds) - - num_reqs = cu_num_logits.shape[0] - 1 - _expand_sampling_metadata_kernel[(num_reqs,)]( - sampling_metadata.temperature, - expanded_temp, - 
sampling_metadata.top_p, - expanded_top_p, - sampling_metadata.top_k, - expanded_top_k, - sampling_metadata.min_p, - expanded_min_p, - sampling_metadata.repetition_penalty, - expanded_repetition_penalty, - sampling_metadata.frequency_penalty, - expanded_frequency_penalty, - sampling_metadata.presence_penalty, - expanded_presence_penalty, - sampling_metadata.seeds, - expanded_seeds, - cu_num_logits, - BLOCK_SIZE=triton.next_power_of_2(max_expand_len), - ) - return SamplingMetadata( - temperature=expanded_temp, - top_p=expanded_top_p, - top_k=expanded_top_k, - min_p=expanded_min_p, - seeds=expanded_seeds, - repetition_penalty=expanded_repetition_penalty, - frequency_penalty=expanded_frequency_penalty, - presence_penalty=expanded_presence_penalty, - pos=sampling_metadata.pos, - max_num_logprobs=sampling_metadata.max_num_logprobs, - # TODO(woosuk): Support penalties with spec decoding. - idx_mapping=sampling_metadata.idx_mapping, - prompt_bin_mask=sampling_metadata.prompt_bin_mask, - output_bin_counts=sampling_metadata.output_bin_counts, - ) diff --git a/vllm/v1/worker/gpu/sample/min_p.py b/vllm/v1/worker/gpu/sample/min_p.py index c98a42cb2..26c3e5905 100644 --- a/vllm/v1/worker/gpu/sample/min_p.py +++ b/vllm/v1/worker/gpu/sample/min_p.py @@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton def _min_p_kernel( logits_ptr, logits_stride, + idx_mapping_ptr, min_p_ptr, vocab_size, BLOCK_SIZE: tl.constexpr, ): req_idx = tl.program_id(0) - min_p = tl.load(min_p_ptr + req_idx).to(tl.float32) + req_state_idx = tl.load(idx_mapping_ptr + req_idx) + min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32) if min_p == 0.0: return @@ -39,12 +41,17 @@ def _min_p_kernel( tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask) -def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None: +def apply_min_p( + logits: torch.Tensor, + idx_mapping: torch.Tensor, + min_p: torch.Tensor, +) -> None: num_reqs, vocab_size = logits.shape BLOCK_SIZE = 1024 _min_p_kernel[(num_reqs,)]( logits, logits.stride(0), + idx_mapping, min_p, vocab_size, BLOCK_SIZE=BLOCK_SIZE, diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py index b4fcc822e..26b0346b2 100644 --- a/vllm/v1/worker/gpu/sample/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata def _penalties_and_temperature_kernel( logits_ptr, logits_stride, + idx_mapping_ptr, repetition_penalty_ptr, frequency_penalty_ptr, presence_penalty_ptr, temperature_ptr, - idx_mapping_ptr, prompt_bin_mask_ptr, prompt_bin_mask_stride, output_bin_counts_ptr, @@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel( BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) - rep_penalty = tl.load(repetition_penalty_ptr + batch_idx) - freq_penalty = tl.load(frequency_penalty_ptr + batch_idx) - pres_penalty = tl.load(presence_penalty_ptr + batch_idx) - temperature = tl.load(temperature_ptr + batch_idx) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx) + freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx) + pres_penalty = tl.load(presence_penalty_ptr + req_state_idx) + temperature = tl.load(temperature_ptr + req_state_idx) temperature = tl.where(temperature == 0.0, 1.0, temperature) use_rep_penalty = rep_penalty != 1.0 @@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel( logits = logits.to(tl.float32) if use_penalty: - req_state_idx = 
tl.load(idx_mapping_ptr + batch_idx) output_bin_counts = tl.load( output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, mask=mask, @@ -92,11 +92,11 @@ def apply_penalties_and_temperature( _penalties_and_temperature_kernel[(num_reqs, num_blocks)]( logits, logits.stride(0), + sampling_metadata.idx_mapping, sampling_metadata.repetition_penalty, sampling_metadata.frequency_penalty, sampling_metadata.presence_penalty, sampling_metadata.temperature, - sampling_metadata.idx_mapping, sampling_metadata.prompt_bin_mask, sampling_metadata.prompt_bin_mask.stride(0), sampling_metadata.output_bin_counts, diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 84a3e1867..6ed849ec8 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -71,7 +71,7 @@ class Sampler: apply_penalties_and_temperature(logits, sampling_metadata) # Apply min_p in place. if sampling_metadata.min_p is not None: - apply_min_p(logits, sampling_metadata.min_p) + apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p) # Apply top_k and/or top_p. This might return a new tensor. logits = apply_top_k_top_p( logits, sampling_metadata.top_k, sampling_metadata.top_p @@ -79,6 +79,7 @@ class Sampler: sampled = gumbel_sample( logits, + sampling_metadata.idx_mapping, sampling_metadata.temperature, sampling_metadata.seeds, sampling_metadata.pos, diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index 71cfaff13..ed9260120 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any -import numpy as np import torch import torch.nn as nn @@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.triton_utils import tl, triton -from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_attn_metadata @@ -46,7 +44,6 @@ class EagleSpeculator: self.hidden_size = self.draft_model_config.get_hidden_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() self.vocab_size = self.draft_model_config.get_vocab_size() - self.pin_memory = is_pin_memory_available() self.dtype = vllm_config.model_config.dtype self.input_buffers = InputBuffers( @@ -56,7 +53,6 @@ class EagleSpeculator: vocab_size=self.vocab_size, dtype=self.dtype, device=device, - pin_memory=self.pin_memory, ) self.hidden_states = torch.zeros( self.max_num_tokens, @@ -64,6 +60,11 @@ class EagleSpeculator: dtype=self.dtype, device=device, ) + self.idx_mapping = torch.zeros( + self.max_num_reqs, + dtype=torch.int32, + device=device, + ) self.temperature = torch.zeros( self.max_num_reqs, dtype=torch.float32, @@ -140,7 +141,7 @@ class EagleSpeculator: num_tokens_across_dp: torch.Tensor | None, ) -> None: pos = self.input_buffers.positions[:num_reqs] - query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1] + query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] for step in range(1, self.num_speculative_steps): # Run the eagle model. 
last_hidden_states, hidden_states = self.run_model( @@ -152,8 +153,9 @@ class EagleSpeculator: # used for draft and target sampling. draft_tokens = gumbel_sample( logits, - self.temperature[:num_reqs], - self.seeds[:num_reqs], + self.idx_mapping[:num_reqs], + self.temperature, + self.seeds, pos + 1, apply_temperature=True, ) @@ -237,23 +239,27 @@ class EagleSpeculator: logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs - cu_num_logits = input_batch.cu_num_logits[:num_reqs] # NOTE(woosuk): For draft sampling, we only consider the temperature # and ignore the other sampling parameters such as top_k and top_p, # for simplicity and performance. # While this may slightly degrade the acceptance rate, it does not # affect the output distribution after rejection sampling. - temperature = self.temperature[:num_reqs] - seeds = self.seeds[:num_reqs] - pos = self.input_buffers.positions[:num_reqs] + idx_mapping = self.idx_mapping[:num_reqs] + idx_mapping.copy_(input_batch.idx_mapping) + self.temperature.copy_(sampling_metadata.temperature) + self.seeds.copy_(sampling_metadata.seeds) # Gather the values and copy them to the pre-allocated buffers. - torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature) - torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds) + pos = self.input_buffers.positions[:num_reqs] torch.gather(input_batch.positions, 0, last_token_indices, out=pos) # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # used for draft and target sampling. draft_tokens = gumbel_sample( - logits, temperature, seeds, pos + 1, apply_temperature=True + logits, + idx_mapping, + self.temperature, + self.seeds, + pos + 1, + apply_temperature=True, ) if self.num_speculative_steps == 1: # Early exit. @@ -273,11 +279,8 @@ class EagleSpeculator: self.max_model_len, self.max_num_reqs, ) - query_start_loc = self.input_buffers.query_start_loc - query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings( - query_start_loc_gpu, pos - ) + query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] + slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) if cudagraph_size is not None: @@ -286,8 +289,9 @@ class EagleSpeculator: return self.draft_tokens[:num_reqs] # Run eager mode. - query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) - query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1] + query_start_loc_cpu = torch.arange( + num_reqs + 1, dtype=torch.int32, device="cpu" + ) block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] # FIXME(woosuk): This is UNSAFE!! 
@@ -295,7 +299,7 @@ class EagleSpeculator: attn_metadata_builders=self.attn_metadata_builders, num_reqs=num_reqs, num_tokens=num_reqs, - query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_gpu=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, seq_lens=self.input_buffers.seq_lens[:num_reqs], max_seq_len=self.max_model_len, @@ -484,7 +488,7 @@ def prepare_eagle_decode( input_buffers.positions, input_hidden_states, input_hidden_states.stride(0), - input_buffers.query_start_loc.gpu, + input_buffers.query_start_loc, input_buffers.seq_lens, hidden_size, max_model_len, diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 6823c0c8e..abfc88405 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -8,10 +8,8 @@ import torch from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.utils.math_utils import cdiv -from vllm.utils.platform_utils import is_uva_available -from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor from vllm.v1.outputs import LogprobsTensors -from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.sample.penalties import bincount @@ -29,7 +27,6 @@ class RequestState: num_speculative_steps: int, vocab_size: int, device: torch.device, - pin_memory: bool, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -37,7 +34,6 @@ class RequestState: self.num_speculative_steps = num_speculative_steps self.vocab_size = vocab_size self.device = device - self.pin_memory = pin_memory self.req_id_to_index: dict[str, int] = {} self.index_to_req_id: dict[int, str] = {} @@ -47,16 +43,18 @@ class RequestState: self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32) # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) # depending on the configured max_num_reqs and max_model_len. - self.prefill_token_ids = UvaBuffer( - self.max_num_reqs, self.max_model_len, dtype=torch.int32 + # To save GPU memory, we use UVA instead of GPU for this tensor. + self.prefill_token_ids = StagedWriteTensor( + (self.max_num_reqs, self.max_model_len), + dtype=torch.int32, + device=device, + uva_instead_of_gpu=True, ) - # NOTE(woosuk): We don't use UVA for prefill_len because its GPU view - # can be used outside of update_states and prepare_inputs. - # Without async barrier, using UVA can cause race conditions. - self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32) + # Number of computed tokens. self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) - self.num_computed_tokens = torch.zeros( + self.num_computed_tokens = StagedWriteTensor( self.max_num_reqs, dtype=torch.int32, device=device ) @@ -84,14 +82,16 @@ class RequestState: self.lora_ids.fill(NO_LORA_ID) # Sampling parameters. 
- self.temperature = self._make_param(self.max_num_reqs, torch.float32) - self.top_p = self._make_param(self.max_num_reqs, torch.float32) - self.top_k = self._make_param(self.max_num_reqs, torch.int32) - self.min_p = self._make_param(self.max_num_reqs, torch.float32) - self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32) - self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32) - self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32) - self.seeds = self._make_param(self.max_num_reqs, torch.int64) + self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32) + self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32) + self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32) + self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32) + self.repetition_penalty = UvaBackedTensor( + self.max_num_reqs, dtype=torch.float32 + ) + self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32) + self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32) + self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64) self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32) # -1 means no logprobs are requested. @@ -111,13 +111,7 @@ class RequestState: self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device ) - def _make_param(self, size: int, dtype: torch.dtype) -> "Param": - return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory) - - def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer( - size, dtype=dtype, device=self.device, pin_memory=self.pin_memory - ) + self._penalties_reqs: list[int] = [] @property def num_reqs(self) -> int: @@ -144,12 +138,9 @@ class RequestState: f"prefill_len {prefill_len} < prompt_len {prompt_len}" ) self.prefill_len.np[req_idx] = prefill_len - self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids - + self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids) self.num_computed_prefill_tokens[req_idx] = num_computed_tokens - # FIXME(woosuk): This triggers a GPU operation whenever adding a new request. - # Optimize this. - self.num_computed_tokens[req_idx] = num_computed_tokens + self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens) if lora_request is not None: self.lora_ids[req_idx] = lora_request.lora_int_id @@ -169,13 +160,7 @@ class RequestState: self.presence_penalty.np[req_idx] = sampling_params.presence_penalty if use_penalty(sampling_params): - bincount( - self.prefill_token_ids.gpu[req_idx], - prefill_len, - prompt_len, - self.prompt_bin_mask[req_idx], - self.output_bin_counts[req_idx], - ) + self._penalties_reqs.append(req_idx) if sampling_params.seed is not None: seed = sampling_params.seed @@ -193,6 +178,22 @@ class RequestState: needs_prompt_logprobs = sampling_params.prompt_logprobs is not None self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs + def apply_staged_writes(self) -> None: + self.prefill_len.copy_to_uva() + self.prefill_token_ids.apply_write() + self.num_computed_tokens.apply_write() + + # TODO(woosuk): Optimize this. 
+ for req_idx in self._penalties_reqs: + bincount( + self.prefill_token_ids.gpu[req_idx], + int(self.prefill_len.np[req_idx]), + int(self.prompt_len[req_idx]), + self.prompt_bin_mask[req_idx], + self.output_bin_counts[req_idx], + ) + self._penalties_reqs.clear() + def remove_request(self, req_id: str) -> None: self.extra_data.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None) @@ -208,30 +209,25 @@ class RequestState: idx_mapping_np: np.ndarray, pos: torch.Tensor, ) -> SamplingMetadata: - temperature = self.temperature.np[idx_mapping_np] - temperature = self.temperature.copy_np_to_gpu(temperature) + temperature = self.temperature.copy_to_uva() top_p = self.top_p.np[idx_mapping_np] no_top_p = np.all(top_p == 1.0) - top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None + top_p = self.top_p.copy_to_uva()[idx_mapping] if not no_top_p else None top_k = self.top_k.np[idx_mapping_np] no_top_k = np.all(top_k == self.vocab_size) - top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None + top_k = self.top_k.copy_to_uva()[idx_mapping] if not no_top_k else None min_p = self.min_p.np[idx_mapping_np] no_min_p = np.all(min_p == 0.0) - min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None + min_p = self.min_p.copy_to_uva() if not no_min_p else None - rep_penalty = self.repetition_penalty.np[idx_mapping_np] - rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty) - freq_penalty = self.frequency_penalty.np[idx_mapping_np] - freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty) - pres_penalty = self.presence_penalty.np[idx_mapping_np] - pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty) + rep_penalty = self.repetition_penalty.copy_to_uva() + freq_penalty = self.frequency_penalty.copy_to_uva() + pres_penalty = self.presence_penalty.copy_to_uva() - seeds = self.seeds.np[idx_mapping_np] - seeds = self.seeds.copy_np_to_gpu(seeds) + seeds = self.seeds.copy_to_uva() num_logprobs = self.num_logprobs[idx_mapping_np] max_num_logprobs: int | None = int(np.max(num_logprobs)) @@ -239,6 +235,7 @@ class RequestState: max_num_logprobs = None return SamplingMetadata( + idx_mapping=idx_mapping, temperature=temperature, top_p=top_p, top_k=top_k, @@ -246,12 +243,11 @@ class RequestState: repetition_penalty=rep_penalty, frequency_penalty=freq_penalty, presence_penalty=pres_penalty, + prompt_bin_mask=self.prompt_bin_mask, + output_bin_counts=self.output_bin_counts, seeds=seeds, pos=pos, max_num_logprobs=max_num_logprobs, - idx_mapping=idx_mapping, - prompt_bin_mask=self.prompt_bin_mask, - output_bin_counts=self.output_bin_counts, ) def make_lora_inputs( @@ -272,42 +268,12 @@ class RequestState: return prompt_lora_mapping, token_lora_mapping, active_lora_requests -class Param: - def __init__( - self, - size: int, - dtype: torch.dtype, - device: torch.device, - pin_memory: bool, - ): - self.buffer = CpuGpuBuffer( - size, - dtype=dtype, - device=device, - pin_memory=pin_memory, - ) - self.np = np.zeros_like(self.buffer.np) - - def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor: - n = x.shape[0] - self.buffer.np[:n] = x - return self.buffer.copy_to_gpu(n) - - @dataclass class ExtraData: lora_request: LoRARequest | None in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) -class UvaBuffer: - def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype): - assert is_uva_available() - self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True) - self.np = self.cpu.numpy() - self.gpu = 
get_cuda_view_from_cpu_tensor(self.cpu)
-
-
 def use_penalty(sampling_params: SamplingParams) -> bool:
     return (
         sampling_params.repetition_penalty != 1.0
diff --git a/vllm/v1/worker/gpu/structured_outputs.py b/vllm/v1/worker/gpu/structured_outputs.py
index 83051b0ed..2eaadbb0b 100644
--- a/vllm/v1/worker/gpu/structured_outputs.py
+++ b/vllm/v1/worker/gpu/structured_outputs.py
@@ -4,38 +4,65 @@
 import numpy as np
 import torch
 
 from vllm.triton_utils import tl, triton
-from vllm.v1.worker.gpu.input_batch import InputBuffers
+from vllm.utils.math_utils import cdiv
+from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
+from vllm.v1.worker.gpu.input_batch import InputBatch
 
 
-def apply_grammar_bitmask(
-    logits: torch.Tensor,
-    req_ids: list[str],
-    grammar_req_ids: list[str],
-    grammar_bitmask: np.ndarray,
-    input_buffers: InputBuffers,
-) -> None:
-    input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
-    input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
+class StructuredOutputsWorker:
+    def __init__(
+        self,
+        max_num_logits: int,
+        vocab_size: int,
+    ):
+        # NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
+        # to save an unnecessary CPU-to-CPU copy.
+        self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
+        self.grammar_bitmask = UvaBufferPool(
+            (max_num_logits, cdiv(vocab_size, 32)), torch.int32
+        )
 
-    batch_size = logits.shape[0]
-    grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
-    # logits -> bitmask mapping
-    mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
-    input_buffers.bitmask_indices.np[:batch_size] = mapping
-    input_buffers.bitmask_indices.copy_to_gpu(batch_size)
+    def apply_grammar_bitmask(
+        self,
+        logits: torch.Tensor,
+        input_batch: InputBatch,
+        grammar_req_ids: list[str],
+        grammar_bitmask: np.ndarray,
+    ) -> None:
+        if not grammar_req_ids:
+            return
 
-    vocab_size = logits.shape[-1]
-    BLOCK_SIZE = 8192
-    grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
-    _apply_grammar_bitmask_kernel[grid](
-        logits,
-        logits.stride(0),
-        input_buffers.grammar_bitmask.gpu,
-        input_buffers.grammar_bitmask.gpu.stride(0),
-        input_buffers.bitmask_indices.gpu,
-        vocab_size,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
+        # Construct bitmask -> logits mapping
+        mapping: list[int] = []
+        req_ids = input_batch.req_ids
+        cu_num_logits = input_batch.cu_num_logits_np.tolist()
+        req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
+        for grammar_req_id in grammar_req_ids:
+            req_idx = req_id_to_idx[grammar_req_id]
+            logits_start_idx = cu_num_logits[req_idx]
+            logits_end_idx = cu_num_logits[req_idx + 1]
+            mapping.extend(range(logits_start_idx, logits_end_idx))
+        # Copy the mapping.
+        mapping_np = np.array(mapping, dtype=np.int32)
+        logits_indices = self.logits_indices.copy_to_uva(mapping_np)
+
+        # Copy the bitmask.
+ bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask) + + num_masks = bitmask.shape[0] + assert num_masks == len(mapping) + vocab_size = logits.shape[-1] + BLOCK_SIZE = 8192 + grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE)) + _apply_grammar_bitmask_kernel[grid]( + logits, + logits.stride(0), + logits_indices, + bitmask, + bitmask.stride(0), + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) # Adapted from @@ -44,17 +71,14 @@ def apply_grammar_bitmask( def _apply_grammar_bitmask_kernel( logits_ptr, logits_stride, + logits_indices_ptr, bitmask_ptr, bitmask_stride, - bitmask_indices_ptr, vocab_size, BLOCK_SIZE: tl.constexpr, ): - logits_idx = tl.program_id(0) - bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx) - if bitmask_idx == -1: - # No bitmask to apply. - return + bitmask_idx = tl.program_id(0) + logits_idx = tl.load(logits_indices_ptr + bitmask_idx) # Load the bitmask. block_id = tl.program_id(1)
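# Illustrative sketch (not part of the diff): how StructuredOutputsWorker is
# expected to be driven, and what the bitmask -> logits mapping looks like.
# The call-site names (`worker`, `scheduler_output` and its fields) are
# assumptions for illustration only; the real wiring lives in the GPU model
# runner, outside this change.
#
# Example mapping: with input_batch.cu_num_logits_np == [0, 1, 3, 4]
# (three requests producing 1, 2, and 1 logits rows) and a grammar attached
# only to the second request, mapping == [1, 2]. The kernel launches a
# (num_masks, cdiv(vocab_size, BLOCK_SIZE)) grid, and bitmask row i is
# applied to logits row mapping[i].

def apply_structured_outputs(worker, logits, input_batch, scheduler_output):
    # grammar_bitmask is an np.ndarray of packed int32 masks, or None when no
    # request in this step uses structured outputs (assumed field names).
    grammar_bitmask = scheduler_output.grammar_bitmask
    if grammar_bitmask is not None:
        worker.apply_grammar_bitmask(
            logits,
            input_batch,
            scheduler_output.grammar_req_ids,
            grammar_bitmask,
        )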