[Model Runner V2] support bad_words sampling param (#33433)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com> Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -156,8 +156,8 @@ def _prepare_prefill_inputs_kernel(
|
||||
next_prefill_tokens_ptr,
|
||||
idx_mapping_ptr,
|
||||
query_start_loc_ptr,
|
||||
prefill_token_ids_ptr,
|
||||
prefill_token_ids_stride,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prefill_lens_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
@@ -174,16 +174,16 @@ def _prepare_prefill_inputs_kernel(
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
|
||||
request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
|
||||
tokens = tl.load(request_ptr + num_computed + block, mask=mask)
|
||||
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
|
||||
|
||||
next_pos = num_computed + query_len
|
||||
if next_pos < prefill_len:
|
||||
next_token = tl.load(prefill_ptr + next_pos)
|
||||
next_token = tl.load(request_ptr + next_pos)
|
||||
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
|
||||
|
||||
|
||||
@@ -192,7 +192,7 @@ def prepare_prefill_inputs(
|
||||
next_prefill_tokens: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
prefill_token_ids: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
) -> None:
|
||||
@@ -202,8 +202,8 @@ def prepare_prefill_inputs(
|
||||
next_prefill_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
prefill_token_ids,
|
||||
prefill_token_ids.stride(0),
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prefill_len,
|
||||
num_computed_tokens,
|
||||
BLOCK_SIZE=1024,
|
||||
@@ -423,16 +423,21 @@ def _post_update_kernel(
|
||||
num_sampled_ptr,
|
||||
num_rejected_ptr,
|
||||
query_start_loc_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
total_len_ptr,
|
||||
):
|
||||
req_id = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_id)
|
||||
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
num_sampled = tl.load(num_sampled_ptr + req_id)
|
||||
if num_sampled > 0:
|
||||
token_id = tl.load(
|
||||
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
|
||||
)
|
||||
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
|
||||
tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
|
||||
|
||||
for i in range(num_sampled):
|
||||
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
|
||||
@@ -442,6 +447,10 @@ def _post_update_kernel(
|
||||
count = tl.load(token_ptr)
|
||||
count += 1
|
||||
tl.store(token_ptr, count)
|
||||
tl.store(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
|
||||
token_id,
|
||||
)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + req_id)
|
||||
query_end = tl.load(query_start_loc_ptr + req_id + 1)
|
||||
@@ -470,6 +479,10 @@ def post_update(
|
||||
num_rejected: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor,
|
||||
# [max_num_reqs, max_model_len]
|
||||
all_token_ids: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
total_len: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_post_update_kernel[(num_reqs,)](
|
||||
@@ -483,6 +496,9 @@ def post_update(
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
query_start_loc,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
total_len,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,6 +151,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
all_token_ids=self.req_states.all_token_ids.gpu,
|
||||
prompt_len=self.req_states.prompt_len.gpu,
|
||||
total_len=self.req_states.total_len.gpu,
|
||||
logprobs_mode=self.model_config.logprobs_mode,
|
||||
num_speculative_tokens=self.num_speculative_steps + 1,
|
||||
)
|
||||
@@ -448,7 +451,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.add_request(
|
||||
req_id=req_id,
|
||||
prompt_len=prompt_len,
|
||||
prefill_token_ids=new_req_data.prefill_token_ids,
|
||||
all_token_ids=new_req_data.prefill_token_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
)
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
@@ -479,9 +482,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.apply_staged_writes()
|
||||
self.sampler.apply_staged_writes(
|
||||
self.req_states.prefill_token_ids.gpu,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prefill_len.np,
|
||||
self.req_states.prompt_len,
|
||||
self.req_states.prompt_len.np,
|
||||
)
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.apply_staged_writes()
|
||||
@@ -570,7 +573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.next_prefill_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.req_states.prefill_token_ids.gpu,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
)
|
||||
@@ -759,6 +762,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
input_batch.query_start_loc,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.total_len.gpu,
|
||||
)
|
||||
|
||||
# Update the number of computed prefill tokens.
|
||||
@@ -924,9 +929,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model.compute_logits,
|
||||
hidden_states,
|
||||
input_batch,
|
||||
self.req_states.prefill_token_ids.gpu,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
self.req_states.prompt_len,
|
||||
self.req_states.prompt_len.np,
|
||||
self.req_states.prefill_len.np,
|
||||
self.req_states.num_computed_prefill_tokens,
|
||||
)
|
||||
|
||||
209
vllm/v1/worker/gpu/sample/bad_words.py
Normal file
209
vllm/v1/worker/gpu/sample/bad_words.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
|
||||
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
|
||||
|
||||
|
||||
class BadWordsState:
|
||||
def __init__(
|
||||
self,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
total_len: torch.Tensor,
|
||||
):
|
||||
self.all_token_ids = all_token_ids
|
||||
self.prompt_len = prompt_len
|
||||
self.total_len = total_len
|
||||
|
||||
self.max_num_reqs = prompt_len.shape[0]
|
||||
self.device = prompt_len.device
|
||||
|
||||
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
|
||||
self.bad_word_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
|
||||
self.bad_word_offsets = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# number of bad words per request
|
||||
self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
# whether request uses bad words
|
||||
self.use_bad_words = np.zeros(self.max_num_reqs, dtype=bool)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_idx: int,
|
||||
sampling_params: SamplingParams,
|
||||
) -> None:
|
||||
bad_words_token_ids = sampling_params.bad_words_token_ids
|
||||
if not bad_words_token_ids:
|
||||
self.num_bad_words.np[req_idx] = 0
|
||||
self.use_bad_words[req_idx] = False
|
||||
return
|
||||
|
||||
num_bad_words = len(bad_words_token_ids)
|
||||
if num_bad_words > MAX_NUM_BAD_WORDS:
|
||||
raise ValueError(
|
||||
f"Too many bad words: {num_bad_words}. "
|
||||
f"The max number is {MAX_NUM_BAD_WORDS}."
|
||||
)
|
||||
|
||||
# Flatten bad words and compute offsets
|
||||
flattened_tokens: list[int] = []
|
||||
offsets: list[int] = [0]
|
||||
for bad_word in bad_words_token_ids:
|
||||
flattened_tokens.extend(bad_word)
|
||||
offsets.append(len(flattened_tokens))
|
||||
|
||||
if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
|
||||
raise ValueError(
|
||||
f"Too many total bad word tokens: {len(flattened_tokens)}. "
|
||||
f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
|
||||
)
|
||||
|
||||
# Stage writes
|
||||
self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
|
||||
self.bad_word_offsets.stage_write(req_idx, 0, offsets)
|
||||
self.num_bad_words.np[req_idx] = num_bad_words
|
||||
self.use_bad_words[req_idx] = True
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.num_bad_words.copy_to_uva()
|
||||
self.bad_word_token_ids.apply_write()
|
||||
self.bad_word_offsets.apply_write()
|
||||
|
||||
def apply_bad_words(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
) -> None:
|
||||
if not np.any(self.use_bad_words[idx_mapping_np]):
|
||||
# No request uses bad words. Skip the kernel launch.
|
||||
return
|
||||
|
||||
actual_max_num_bad_words = int(np.max(self.num_bad_words.np[idx_mapping_np]))
|
||||
apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
self.bad_word_token_ids.gpu,
|
||||
self.bad_word_offsets.gpu,
|
||||
self.num_bad_words.gpu,
|
||||
self.all_token_ids,
|
||||
self.prompt_len,
|
||||
self.total_len,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
actual_max_num_bad_words,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bad_words_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
expanded_idx_mapping_ptr,
|
||||
bad_word_token_ids_ptr,
|
||||
bad_word_token_ids_stride,
|
||||
bad_word_offsets_ptr,
|
||||
bad_word_offsets_stride,
|
||||
num_bad_words_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
total_len_ptr,
|
||||
input_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
):
|
||||
logit_idx = tl.program_id(0)
|
||||
bw_idx = tl.program_id(1)
|
||||
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
|
||||
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
|
||||
|
||||
if bw_idx >= num_bad_words:
|
||||
return
|
||||
|
||||
pos = tl.load(expanded_local_pos_ptr + logit_idx)
|
||||
cur_req_first_pos = logit_idx - pos
|
||||
|
||||
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
output_len = total_len - prompt_len
|
||||
effective_len = output_len + pos
|
||||
|
||||
bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
|
||||
bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
|
||||
output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len
|
||||
|
||||
start = tl.load(bd_offsets_base + bw_idx)
|
||||
end = tl.load(bd_offsets_base + bw_idx + 1)
|
||||
bad_word_len = end - start
|
||||
prefix_len = bad_word_len - 1
|
||||
|
||||
if prefix_len > effective_len:
|
||||
return
|
||||
|
||||
last_token = tl.load(bd_tokens_base + end - 1)
|
||||
match = 1
|
||||
for i in range(prefix_len):
|
||||
expected = tl.load(bd_tokens_base + start + i)
|
||||
actual_pos = effective_len - prefix_len + i
|
||||
|
||||
from_spec_input = actual_pos >= output_len
|
||||
if from_spec_input:
|
||||
spec_offset = actual_pos - output_len
|
||||
actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
|
||||
else:
|
||||
actual = tl.load(output_base + actual_pos)
|
||||
|
||||
match = match & (expected == actual)
|
||||
|
||||
if match:
|
||||
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
logits: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
bad_word_token_ids: torch.Tensor,
|
||||
bad_word_offsets: torch.Tensor,
|
||||
num_bad_words: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
total_len: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
max_num_bad_words: int,
|
||||
) -> None:
|
||||
total_num_tokens = logits.shape[0]
|
||||
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
expanded_idx_mapping,
|
||||
bad_word_token_ids,
|
||||
bad_word_token_ids.stride(0),
|
||||
bad_word_offsets,
|
||||
bad_word_offsets.stride(0),
|
||||
num_bad_words,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
total_len,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
@@ -51,14 +51,14 @@ class PenaltiesState:
|
||||
|
||||
def apply_staged_writes(
|
||||
self,
|
||||
prefill_token_ids: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_lens: np.ndarray,
|
||||
prompt_lens: np.ndarray,
|
||||
) -> None:
|
||||
# TODO(woosuk): Optimize this.
|
||||
for req_idx in self._penalties_reqs:
|
||||
bincount(
|
||||
prefill_token_ids[req_idx],
|
||||
all_token_ids[req_idx],
|
||||
int(prefill_lens[req_idx]),
|
||||
int(prompt_lens[req_idx]),
|
||||
self.prompt_bin_mask[req_idx],
|
||||
@@ -216,7 +216,7 @@ def apply_penalties(
|
||||
|
||||
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
|
||||
def _bincount_kernel(
|
||||
prefill_token_ids_ptr,
|
||||
all_token_ids_ptr,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
prompt_bin_mask_ptr,
|
||||
@@ -230,20 +230,20 @@ def _bincount_kernel(
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
if block_idx * BLOCK_SIZE < prompt_len:
|
||||
mask = block < prompt_len
|
||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||
idx = prefill_tokens // 32
|
||||
bit_idx = prefill_tokens % 32
|
||||
prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
|
||||
idx = prompt_tokens // 32
|
||||
bit_idx = prompt_tokens % 32
|
||||
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
|
||||
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
|
||||
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
||||
mask = block < prefill_len
|
||||
mask &= block >= prompt_len
|
||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||
output_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask)
|
||||
|
||||
|
||||
def bincount(
|
||||
prefill_token_ids: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_len: int,
|
||||
prompt_len: int,
|
||||
prompt_bin_mask: torch.Tensor,
|
||||
@@ -254,7 +254,7 @@ def bincount(
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_blocks,)](
|
||||
prefill_token_ids,
|
||||
all_token_ids,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
prompt_bin_mask,
|
||||
|
||||
@@ -36,7 +36,7 @@ class PromptLogprobsWorker:
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
# [max_num_reqs, max_model_len]
|
||||
prefill_token_ids: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
@@ -70,7 +70,7 @@ class PromptLogprobsWorker:
|
||||
input_batch.query_start_loc,
|
||||
input_batch.idx_mapping,
|
||||
num_computed_tokens,
|
||||
prefill_token_ids,
|
||||
all_token_ids,
|
||||
)
|
||||
# Compute the prompt logprobs.
|
||||
prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
|
||||
@@ -132,8 +132,8 @@ def _prompt_logprobs_token_ids_kernel(
|
||||
query_start_loc_ptr,
|
||||
idx_mapping_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
prefill_token_ids_ptr,
|
||||
prefill_token_ids_stride,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
@@ -151,9 +151,7 @@ def _prompt_logprobs_token_ids_kernel(
|
||||
# because the logprob is computed for the next token.
|
||||
target_pos = num_computed_tokens + 1 + block
|
||||
token_ids = tl.load(
|
||||
prefill_token_ids_ptr
|
||||
+ req_state_idx * prefill_token_ids_stride
|
||||
+ target_pos,
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
@@ -166,7 +164,7 @@ def get_prompt_logprobs_token_ids(
|
||||
query_start_loc: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
prefill_token_ids: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
@@ -175,8 +173,8 @@ def get_prompt_logprobs_token_ids(
|
||||
query_start_loc,
|
||||
idx_mapping,
|
||||
num_computed_tokens,
|
||||
prefill_token_ids,
|
||||
prefill_token_ids.stride(0),
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return token_ids
|
||||
|
||||
@@ -8,6 +8,7 @@ import vllm.envs as envs
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
|
||||
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
@@ -22,6 +23,9 @@ class Sampler:
|
||||
max_num_reqs: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
total_len: torch.Tensor,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
num_speculative_tokens: int = 1,
|
||||
):
|
||||
@@ -33,6 +37,7 @@ class Sampler:
|
||||
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
|
||||
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
|
||||
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
|
||||
self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len)
|
||||
self.num_speculative_tokens = num_speculative_tokens
|
||||
|
||||
def add_request(
|
||||
@@ -41,18 +46,20 @@ class Sampler:
|
||||
self.sampling_states.add_request(req_idx, sampling_params)
|
||||
self.penalties_state.add_request(req_idx, sampling_params)
|
||||
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
|
||||
self.bad_words_state.add_request(req_idx, sampling_params)
|
||||
|
||||
def apply_staged_writes(
|
||||
self,
|
||||
prefill_token_ids: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_lens: np.ndarray,
|
||||
prompt_lens: np.ndarray,
|
||||
) -> None:
|
||||
self.sampling_states.apply_staged_writes()
|
||||
self.penalties_state.apply_staged_writes(
|
||||
prefill_token_ids, prefill_lens, prompt_lens
|
||||
all_token_ids, prefill_lens, prompt_lens
|
||||
)
|
||||
self.logit_bias_state.apply_staged_writes()
|
||||
self.bad_words_state.apply_staged_writes()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -124,6 +131,15 @@ class Sampler:
|
||||
self.num_speculative_tokens,
|
||||
)
|
||||
|
||||
# Apply bad words masking in place.
|
||||
self.bad_words_state.apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
# Apply temperature in place.
|
||||
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
|
||||
|
||||
|
||||
@@ -27,17 +27,30 @@ class RequestState:
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_reqs))
|
||||
|
||||
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
# depending on the configured max_num_reqs and max_model_len.
|
||||
# To save GPU memory, we use UVA instead of GPU for this tensor.
|
||||
self.prefill_token_ids = StagedWriteTensor(
|
||||
self.all_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, self.max_model_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
uva_instead_of_gpu=True,
|
||||
)
|
||||
# NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
|
||||
# - prompt_len: Number of tokens in the user-provided prompt.
|
||||
# - prefill_len: Number of tokens passed into the model runner.
|
||||
# This can include the prompt and additional partial output tokens,
|
||||
# so prefill_len >= prompt_len.
|
||||
# Usually, prefill_len equals prompt_len, but in cases such as resumption after
|
||||
# preemption, prefill_len may be greater. Differentiating between these values
|
||||
# is crucial, as certain features such as prompt logprobs or frequency penalties
|
||||
# must treat prompt and output tokens separately.
|
||||
self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
# total_len = prompt_len + output_len. It grows as the request progresses.
|
||||
self.total_len = StagedWriteTensor(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Number of computed tokens.
|
||||
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
@@ -72,7 +85,7 @@ class RequestState:
|
||||
self,
|
||||
req_id: str,
|
||||
prompt_len: int,
|
||||
prefill_token_ids: list[int],
|
||||
all_token_ids: list[int],
|
||||
num_computed_tokens: int,
|
||||
) -> None:
|
||||
assert len(self.free_indices) > 0, "No free indices"
|
||||
@@ -80,19 +93,22 @@ class RequestState:
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
|
||||
self.prompt_len[req_idx] = prompt_len
|
||||
prefill_len = len(prefill_token_ids)
|
||||
self.prompt_len.np[req_idx] = prompt_len
|
||||
prefill_len = len(all_token_ids)
|
||||
assert prefill_len >= prompt_len, (
|
||||
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
|
||||
)
|
||||
self.prefill_len.np[req_idx] = prefill_len
|
||||
self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids)
|
||||
self.total_len.stage_write_elem(req_idx, prefill_len)
|
||||
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
|
||||
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
|
||||
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.prompt_len.copy_to_uva()
|
||||
self.prefill_len.copy_to_uva()
|
||||
self.prefill_token_ids.apply_write()
|
||||
self.total_len.apply_write()
|
||||
self.all_token_ids.apply_write()
|
||||
self.num_computed_tokens.apply_write()
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user