diff --git a/vllm/v1/worker/gpu/buffer_utils.py b/vllm/v1/worker/gpu/buffer_utils.py index 0662f0278..dced199aa 100644 --- a/vllm/v1/worker/gpu/buffer_utils.py +++ b/vllm/v1/worker/gpu/buffer_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence +from collections.abc import Iterable, Sequence import numpy as np import torch @@ -92,9 +92,10 @@ class StagedWriteTensor: max_concurrency: int = 2, uva_instead_of_gpu: bool = False, ): - if dtype not in [torch.int32, torch.int64]: + supported_dtypes = [torch.int32, torch.int64, torch.float32] + if dtype not in supported_dtypes: raise ValueError( - f"Unsupported dtype {dtype}: should be either int32 or int64" + f"Unsupported dtype {dtype}: should be one of {supported_dtypes}" ) self.num_rows = size if isinstance(size, int) else size[0] self.dtype = dtype @@ -111,7 +112,7 @@ class StagedWriteTensor: self._staged_write_indices: list[int] = [] self._staged_write_starts: list[int] = [] - self._staged_write_contents: list[int] = [] + self._staged_write_contents: list[int | float] = [] self._staged_write_cu_lens: list[int] = [] self.write_indices = UvaBufferPool( @@ -128,7 +129,12 @@ class StagedWriteTensor: self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency ) - def stage_write(self, index: int, start: int, x: list[int]) -> None: + def stage_write( + self, + index: int, + start: int, + x: Iterable[int] | Iterable[float], + ) -> None: assert index >= 0 assert start >= 0 if not x: diff --git a/vllm/v1/worker/gpu/sample/logit_bias.py b/vllm/v1/worker/gpu/sample/logit_bias.py new file mode 100644 index 000000000..396e177d0 --- /dev/null +++ b/vllm/v1/worker/gpu/sample/logit_bias.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.sampling_params import SamplingParams +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor + +MAX_NUM_ALLOWED_TOKEN_IDS = 1024 +MAX_NUM_LOGIT_BIAS_TOKENS = 1024 +MAX_NUM_STOP_TOKEN_IDS = 128 + + +class LogitBiasState: + def __init__( + self, + max_num_reqs: int, + device: torch.device, + ): + self.max_num_reqs = max_num_reqs + + # Allowed token IDs. + self.num_allowed_token_ids = UvaBackedTensor( + self.max_num_reqs, dtype=torch.int32 + ) + self.allowed_token_ids = StagedWriteTensor( + (self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS), + dtype=torch.int32, + device=device, + ) + # Logit bias. + self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32) + self.logit_bias_token_ids = StagedWriteTensor( + (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS), + dtype=torch.int32, + device=device, + ) + self.logit_bias = StagedWriteTensor( + (self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS), + dtype=torch.float32, + device=device, + ) + # Min tokens. + self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32) + self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32) + self.stop_token_ids = StagedWriteTensor( + (self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS), + dtype=torch.int32, + device=device, + ) + + def add_request( + self, + req_idx: int, + prompt_len: int, + sampling_params: SamplingParams, + ) -> None: + # Allowed token IDs. + allowed_token_ids = sampling_params.allowed_token_ids + if allowed_token_ids: + num_allowed_token_ids = len(allowed_token_ids) + if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS: + raise ValueError( + f"Too many allowed token IDs: {num_allowed_token_ids}. " + f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}." + ) + self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids + self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids) + else: + self.num_allowed_token_ids.np[req_idx] = 0 + + # Logit bias. + logit_bias = sampling_params.logit_bias + if logit_bias: + num_logit_bias = len(logit_bias) + if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS: + raise ValueError( + f"Too many logit bias tokens: {num_logit_bias}. " + f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}." + ) + self.num_logit_bias.np[req_idx] = num_logit_bias + self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys()) + self.logit_bias.stage_write(req_idx, 0, logit_bias.values()) + else: + self.num_logit_bias.np[req_idx] = 0 + + # Min tokens. + min_tokens = sampling_params.min_tokens + min_len = prompt_len + min_tokens + self.min_lens.np[req_idx] = min_len + stop_token_ids = sampling_params.all_stop_token_ids + if stop_token_ids: + num_stop_token_ids = len(stop_token_ids) + if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS: + raise ValueError( + f"Too many stop tokens: {num_stop_token_ids}. " + f"The max size is {MAX_NUM_STOP_TOKEN_IDS}." + ) + self.num_stop_token_ids.np[req_idx] = num_stop_token_ids + self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids) + else: + self.num_stop_token_ids.np[req_idx] = 0 + + def apply_staged_writes(self) -> None: + self.num_allowed_token_ids.copy_to_uva() + self.allowed_token_ids.apply_write() + + self.num_logit_bias.copy_to_uva() + self.logit_bias_token_ids.apply_write() + self.logit_bias.apply_write() + + self.min_lens.copy_to_uva() + self.num_stop_token_ids.copy_to_uva() + self.stop_token_ids.apply_write() + + def apply_logit_bias( + self, + logits: torch.Tensor, + idx_mapping: torch.Tensor, + pos: torch.Tensor, + ) -> None: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = triton.next_power_of_2( + max( + MAX_NUM_ALLOWED_TOKEN_IDS, + MAX_NUM_LOGIT_BIAS_TOKENS, + MAX_NUM_STOP_TOKEN_IDS, + ) + ) + LOGITS_BLOCK_SIZE = 8192 + _bias_kernel[(num_reqs,)]( + logits, + logits.stride(0), + vocab_size, + idx_mapping, + self.num_allowed_token_ids, + self.allowed_token_ids, + self.allowed_token_ids.gpu.stride(0), + self.num_logit_bias, + self.logit_bias_token_ids, + self.logit_bias_token_ids.gpu.stride(0), + self.logit_bias, + self.logit_bias.gpu.stride(0), + pos, + self.min_lens, + self.num_stop_token_ids, + self.stop_token_ids, + self.stop_token_ids.gpu.stride(0), + BLOCK_SIZE=BLOCK_SIZE, + LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE, + ) + + +@triton.jit +def _bias_kernel( + logits_ptr, + logits_stride, + vocab_size, + idx_mapping_ptr, + # Allowed token IDs. + num_allowed_token_ids_ptr, + allowed_token_ids_ptr, + allowed_token_ids_stride, + # Logit bias. + num_logit_bias_ptr, + bias_token_ids_ptr, + bias_token_ids_stride, + bias_ptr, + bias_stride, + # Min tokens. + pos_ptr, + min_lens_ptr, + num_stop_token_ids_ptr, + stop_token_ids_ptr, + stop_token_ids_stride, + BLOCK_SIZE: tl.constexpr, + LOGITS_BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + block = tl.arange(0, BLOCK_SIZE) + + # Allowed token IDs. + num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx) + if num_allowed_token_ids > 0: + block = tl.arange(0, BLOCK_SIZE) + mask = block < num_allowed_token_ids + + # Save logits for allowed token IDs. + allowed_token_ids = tl.load( + allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block, + mask=mask, + ) + logits = tl.load( + logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask + ) + + # Set logits to -inf for all tokens. + for i in range(0, vocab_size, LOGITS_BLOCK_SIZE): + offset = i + tl.arange(0, LOGITS_BLOCK_SIZE) + tl.store( + logits_ptr + batch_idx * logits_stride + offset, + -float("inf"), + mask=offset < vocab_size, + ) + + # Restore logits for allowed token IDs. + tl.store( + logits_ptr + batch_idx * logits_stride + allowed_token_ids, + logits, + mask=mask, + ) + + # Logit bias. + num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx) + if num_logit_bias > 0: + mask = block < num_logit_bias + token_ids = tl.load( + bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block, + mask=mask, + ) + bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask) + logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask) + logits += bias + tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask) + + # Apply min tokens. + num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx) + pos = tl.load(pos_ptr + batch_idx) + min_len = tl.load(min_lens_ptr + req_state_idx) + if num_stop_token_ids > 0 and pos < min_len: + mask = block < num_stop_token_ids + stop_token_ids = tl.load( + stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block, + mask=mask, + ) + tl.store( + logits_ptr + batch_idx * logits_stride + stop_token_ids, + -float("inf"), + mask=mask, + )