[Model Runner V2] Minor refactoring for penalties (#34662)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -155,9 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
all_token_ids=self.req_states.all_token_ids.gpu,
|
||||
prompt_len=self.req_states.prompt_len.gpu,
|
||||
total_len=self.req_states.total_len.gpu,
|
||||
req_states=self.req_states,
|
||||
logprobs_mode=self.model_config.logprobs_mode,
|
||||
num_speculative_tokens=self.num_speculative_steps + 1,
|
||||
)
|
||||
@@ -528,11 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.apply_staged_writes()
|
||||
self.sampler.apply_staged_writes(
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prefill_len.np,
|
||||
self.req_states.prompt_len.np,
|
||||
)
|
||||
self.sampler.apply_staged_writes()
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.apply_staged_writes()
|
||||
|
||||
|
||||
@@ -6,24 +6,17 @@ import torch
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
|
||||
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
|
||||
|
||||
|
||||
class BadWordsState:
|
||||
def __init__(
|
||||
self,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
total_len: torch.Tensor,
|
||||
):
|
||||
self.all_token_ids = all_token_ids
|
||||
self.prompt_len = prompt_len
|
||||
self.total_len = total_len
|
||||
|
||||
self.max_num_reqs = prompt_len.shape[0]
|
||||
self.device = prompt_len.device
|
||||
def __init__(self, req_states: RequestState):
|
||||
self.req_states = req_states
|
||||
self.max_num_reqs = req_states.max_num_reqs
|
||||
self.device = req_states.device
|
||||
|
||||
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
|
||||
self.bad_word_token_ids = StagedWriteTensor(
|
||||
@@ -95,9 +88,9 @@ class BadWordsState:
|
||||
self.bad_word_token_ids.gpu,
|
||||
self.bad_word_offsets.gpu,
|
||||
self.num_bad_words.gpu,
|
||||
self.all_token_ids,
|
||||
self.prompt_len,
|
||||
self.total_len,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prompt_len.gpu,
|
||||
self.req_states.total_len.gpu,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
max_num_bad_words,
|
||||
|
||||
@@ -6,14 +6,18 @@ import torch
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import async_tensor_h2d
|
||||
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
|
||||
class PenaltiesState:
|
||||
def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
def __init__(self, req_states: RequestState):
|
||||
self.req_states = req_states
|
||||
|
||||
max_num_reqs = req_states.max_num_reqs
|
||||
self.vocab_size = req_states.vocab_size
|
||||
self.device = req_states.device
|
||||
|
||||
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
@@ -26,7 +30,7 @@ class PenaltiesState:
|
||||
|
||||
# Statistics for penalties.
|
||||
self.prompt_bin_mask = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
max_num_reqs,
|
||||
cdiv(self.vocab_size, 32),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
@@ -34,10 +38,10 @@ class PenaltiesState:
|
||||
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
|
||||
# GBs of GPU memory. Optimize the memory usage.
|
||||
self.output_bin_counts = torch.zeros(
|
||||
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
|
||||
max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
self._penalties_reqs: list[int] = []
|
||||
self._new_penalties_reqs: list[int] = []
|
||||
|
||||
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
||||
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
|
||||
@@ -47,24 +51,29 @@ class PenaltiesState:
|
||||
do_penalty = use_penalty(sampling_params)
|
||||
self.use_penalty[req_idx] = do_penalty
|
||||
if do_penalty:
|
||||
self._penalties_reqs.append(req_idx)
|
||||
self._new_penalties_reqs.append(req_idx)
|
||||
|
||||
def apply_staged_writes(
|
||||
self,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_lens: np.ndarray,
|
||||
prompt_lens: np.ndarray,
|
||||
) -> None:
|
||||
# TODO(woosuk): Optimize this.
|
||||
for req_idx in self._penalties_reqs:
|
||||
bincount(
|
||||
all_token_ids[req_idx],
|
||||
int(prefill_lens[req_idx]),
|
||||
int(prompt_lens[req_idx]),
|
||||
self.prompt_bin_mask[req_idx],
|
||||
self.output_bin_counts[req_idx],
|
||||
def apply_staged_writes(self) -> None:
|
||||
if self._new_penalties_reqs:
|
||||
idx_mapping = async_tensor_h2d(
|
||||
self._new_penalties_reqs,
|
||||
dtype=torch.int32,
|
||||
target_device=self.device,
|
||||
pin_memory=True,
|
||||
)
|
||||
self._penalties_reqs.clear()
|
||||
|
||||
prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
|
||||
max_prefill_len = int(prefill_lens.max())
|
||||
bincount(
|
||||
idx_mapping,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prompt_len.gpu,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.prompt_bin_mask,
|
||||
self.output_bin_counts,
|
||||
max_prefill_len,
|
||||
)
|
||||
self._new_penalties_reqs.clear()
|
||||
|
||||
self.repetition_penalty.copy_to_uva()
|
||||
self.frequency_penalty.copy_to_uva()
|
||||
@@ -214,51 +223,82 @@ def apply_penalties(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
|
||||
@triton.jit
|
||||
def _bincount_kernel(
|
||||
idx_mapping_ptr,
|
||||
all_token_ids_ptr,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
prefill_len_ptr,
|
||||
prompt_bin_mask_ptr,
|
||||
prompt_bin_mask_stride,
|
||||
output_bin_counts_ptr,
|
||||
output_bin_counts_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
batch_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
return
|
||||
|
||||
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
if block_idx * BLOCK_SIZE < prompt_len:
|
||||
mask = block < prompt_len
|
||||
prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
|
||||
prompt_tokens = tl.load(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
|
||||
)
|
||||
idx = prompt_tokens // 32
|
||||
bit_idx = prompt_tokens % 32
|
||||
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
|
||||
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
|
||||
tl.atomic_or(
|
||||
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
|
||||
bit,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
||||
mask = block < prefill_len
|
||||
mask &= block >= prompt_len
|
||||
output_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask)
|
||||
output_tokens = tl.load(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
|
||||
)
|
||||
tl.atomic_add(
|
||||
output_bin_counts_ptr
|
||||
+ req_state_idx * output_bin_counts_stride
|
||||
+ output_tokens,
|
||||
1,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def bincount(
|
||||
idx_mapping: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_len: int,
|
||||
prompt_len: int,
|
||||
prompt_len: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
prompt_bin_mask: torch.Tensor,
|
||||
output_bin_counts: torch.Tensor,
|
||||
max_prefill_len: int,
|
||||
) -> None:
|
||||
prompt_bin_mask.zero_()
|
||||
output_bin_counts.zero_()
|
||||
prompt_bin_mask[idx_mapping] = 0
|
||||
output_bin_counts[idx_mapping] = 0
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_blocks,)](
|
||||
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_reqs, num_blocks)](
|
||||
idx_mapping,
|
||||
all_token_ids,
|
||||
prefill_len,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
prefill_len,
|
||||
prompt_bin_mask,
|
||||
prompt_bin_mask.stride(0),
|
||||
output_bin_counts,
|
||||
output_bin_counts.stride(0),
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
|
||||
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
|
||||
class Sampler:
|
||||
@@ -23,9 +24,7 @@ class Sampler:
|
||||
max_num_reqs: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
total_len: torch.Tensor,
|
||||
req_states: RequestState,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
num_speculative_tokens: int = 1,
|
||||
):
|
||||
@@ -35,9 +34,9 @@ class Sampler:
|
||||
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
|
||||
|
||||
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
|
||||
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
|
||||
self.penalties_state = PenaltiesState(req_states)
|
||||
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
|
||||
self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len)
|
||||
self.bad_words_state = BadWordsState(req_states)
|
||||
self.num_speculative_tokens = num_speculative_tokens
|
||||
|
||||
def add_request(
|
||||
@@ -48,16 +47,9 @@ class Sampler:
|
||||
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
|
||||
self.bad_words_state.add_request(req_idx, sampling_params)
|
||||
|
||||
def apply_staged_writes(
|
||||
self,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_lens: np.ndarray,
|
||||
prompt_lens: np.ndarray,
|
||||
) -> None:
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.sampling_states.apply_staged_writes()
|
||||
self.penalties_state.apply_staged_writes(
|
||||
all_token_ids, prefill_lens, prompt_lens
|
||||
)
|
||||
self.penalties_state.apply_staged_writes()
|
||||
self.logit_bias_state.apply_staged_writes()
|
||||
self.bad_words_state.apply_staged_writes()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user