diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py index 0cb80f833..3a0a6b6a0 100644 --- a/vllm/v1/worker/gpu/sample/gumbel.py +++ b/vllm/v1/worker/gpu/sample/gumbel.py @@ -5,6 +5,50 @@ import torch from vllm.triton_utils import tl, triton +@triton.jit +def _temperature_kernel( + logits_ptr, + logits_stride, + idx_mapping_ptr, + temperature_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32) + if temperature == 0.0 or temperature == 1.0: + # Early return to avoid loading logits. + return + + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + + logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) + logits = logits.to(tl.float32) + logits = logits / temperature + tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) + + +def apply_temperature( + logits: torch.Tensor, + idx_mapping: torch.Tensor, + temperature: torch.Tensor, +) -> None: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 8192 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + _temperature_kernel[(num_reqs, num_blocks)]( + logits, + logits.stride(0), + idx_mapping, + temperature, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + @triton.jit def _gumbel_sample_kernel( local_argmax_ptr, @@ -48,7 +92,7 @@ def _gumbel_sample_kernel( # Apply temperature. if APPLY_TEMPERATURE: - # NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel. + # NOTE(woosuk): Match the behavior of _temperature_kernel. # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too. logits = logits / temp diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py index 6226ff15e..4f0ce905f 100644 --- a/vllm/v1/worker/gpu/sample/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -66,16 +66,10 @@ class PenaltiesState: self.frequency_penalty.copy_to_uva() self.presence_penalty.copy_to_uva() - def apply_penalties_and_temperature( - self, - logits: torch.Tensor, - idx_mapping: torch.Tensor, - temperature: torch.Tensor, - ) -> None: - apply_penalties_and_temperature( + def apply_penalties(self, logits: torch.Tensor, idx_mapping: torch.Tensor) -> None: + apply_penalties( logits, idx_mapping, - temperature, self.repetition_penalty.gpu, self.frequency_penalty.gpu, self.presence_penalty.gpu, @@ -85,14 +79,13 @@ class PenaltiesState: @triton.jit -def _penalties_and_temperature_kernel( +def _penalties_kernel( logits_ptr, logits_stride, idx_mapping_ptr, repetition_penalty_ptr, frequency_penalty_ptr, presence_penalty_ptr, - temperature_ptr, prompt_bin_mask_ptr, prompt_bin_mask_stride, output_bin_counts_ptr, @@ -105,15 +98,12 @@ def _penalties_and_temperature_kernel( rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx) freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx) pres_penalty = tl.load(presence_penalty_ptr + req_state_idx) - temperature = tl.load(temperature_ptr + req_state_idx) - temperature = tl.where(temperature == 0.0, 1.0, temperature) use_rep_penalty = rep_penalty != 1.0 use_freq_penalty = freq_penalty != 0.0 use_pres_penalty = pres_penalty != 0.0 use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty - use_temperature = temperature != 1.0 - if not (use_penalty or use_temperature): + if not use_penalty: # Early return to avoid loading logits. return @@ -123,47 +113,39 @@ def _penalties_and_temperature_kernel( logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) logits = logits.to(tl.float32) - if use_penalty: - output_bin_counts = tl.load( - output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, - mask=mask, + output_bin_counts = tl.load( + output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, + mask=mask, + ) + output_bin_mask = output_bin_counts > 0 + + # Apply repetition penalties. + if use_rep_penalty: + packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32) + packed_mask = tl.load( + prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block, + mask=packed_block < tl.cdiv(vocab_size, 32), ) - output_bin_mask = output_bin_counts > 0 + prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1 + prompt_bin_mask = prompt_bin_mask.to(tl.int1) + prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE) - # Apply repetition penalties. - if use_rep_penalty: - packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32) - packed_mask = tl.load( - prompt_bin_mask_ptr - + req_state_idx * prompt_bin_mask_stride - + packed_block, - mask=packed_block < tl.cdiv(vocab_size, 32), - ) - prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1 - prompt_bin_mask = prompt_bin_mask.to(tl.int1) - prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE) - - # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0) - # If logits are positive, divide by penalty, otherwise multiply by penalty. - logits *= tl.where(logits > 0, 1.0 / scale, scale) - - # Apply frequency penalties. - logits -= freq_penalty * output_bin_counts - # Apply presence penalties. - logits -= pres_penalty * output_bin_mask - - # Apply temperature. - logits = logits / temperature + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. + logits *= tl.where(logits > 0, 1.0 / scale, scale) + # Apply frequency penalties. + logits -= freq_penalty * output_bin_counts + # Apply presence penalties. + logits -= pres_penalty * output_bin_mask # Store back to logits. tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) -def apply_penalties_and_temperature( +def apply_penalties( logits: torch.Tensor, idx_mapping: torch.Tensor, - temperature: torch.Tensor, repetition_penalty: torch.Tensor, frequency_penalty: torch.Tensor, presence_penalty: torch.Tensor, @@ -173,14 +155,13 @@ def apply_penalties_and_temperature( num_reqs, vocab_size = logits.shape BLOCK_SIZE = 8192 num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) - _penalties_and_temperature_kernel[(num_reqs, num_blocks)]( + _penalties_kernel[(num_reqs, num_blocks)]( logits, logits.stride(0), idx_mapping, repetition_penalty, frequency_penalty, presence_penalty, - temperature, prompt_bin_mask, prompt_bin_mask.stride(0), output_bin_counts, diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index a9df2b48c..b8d2f0c5a 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -9,7 +9,7 @@ from vllm.config.model import LogprobsMode from vllm.sampling_params import SamplingParams from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.worker.gpu.metrics.logits import get_num_nans -from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.min_p import apply_min_p @@ -106,10 +106,11 @@ class Sampler: # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place. self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos) - # Apply penalties and temperature in place. - self.penalties_state.apply_penalties_and_temperature( - logits, idx_mapping, self.sampling_states.temperature.gpu - ) + # Apply penalties in place. + self.penalties_state.apply_penalties(logits, idx_mapping) + + # Apply temperature in place. + apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu) # Apply min_p in place if any request has a non-zero min_p. do_min_p = self.sampling_states.do_min_p(idx_mapping_np)