[V1] Adding min tokens/repetition/presence/frequency penalties to V1 sampler (#10681)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -11,6 +11,7 @@ import torch
 import torch.nn as nn
 
 import vllm.envs as envs
+from vllm.model_executor.layers.utils import apply_penalties
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
@@ -258,11 +259,11 @@ class Sampler(nn.Module):
 
         # Apply presence and frequency penalties.
         if do_penalties:
-            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
-                                      sampling_tensors.output_tokens,
-                                      sampling_tensors.presence_penalties,
-                                      sampling_tensors.frequency_penalties,
-                                      sampling_tensors.repetition_penalties)
+            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                     sampling_tensors.output_tokens,
+                                     sampling_tensors.presence_penalties,
+                                     sampling_tensors.frequency_penalties,
+                                     sampling_tensors.repetition_penalties)
 
         # Use float32 to apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
@@ -336,23 +337,6 @@ class Sampler(nn.Module):
         return self.should_modify_greedy_probs_inplace
 
 
-def _get_bin_counts_and_mask(
-    tokens: torch.Tensor,
-    vocab_size: int,
-    num_seqs: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Compute the bin counts for the tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=tokens.device)
-    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
-    bin_counts = bin_counts[:, :vocab_size]
-    mask = bin_counts > 0
-
-    return bin_counts, mask
-
-
 def _apply_min_tokens_penalty(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -400,29 +384,6 @@ def _apply_min_tokens_penalty(
     return logits
 
 
-def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
-                     output_tokens_tensor: torch.Tensor,
-                     presence_penalties: torch.Tensor,
-                     frequency_penalties: torch.Tensor,
-                     repetition_penalties: torch.Tensor) -> torch.Tensor:
-    num_seqs, vocab_size = logits.shape
-    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
-                                              num_seqs)
-    output_bin_counts, output_mask = _get_bin_counts_and_mask(
-        output_tokens_tensor, vocab_size, num_seqs)
-
-    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
-    logits = torch.where(logits > 0, logits / repetition_penalties,
-                         logits * repetition_penalties)
-
-    # We follow the definition in OpenAI API.
-    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
-    return logits
-
-
 def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
vllm/model_executor/layers/utils.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+"""Utility methods for model layers."""
+from typing import Tuple
+
+import torch
+
+
+def get_token_bin_counts_and_mask(
+    tokens: torch.Tensor,
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
+
+
+def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
+                    output_tokens_tensor: torch.Tensor,
+                    presence_penalties: torch.Tensor,
+                    frequency_penalties: torch.Tensor,
+                    repetition_penalties: torch.Tensor) -> torch.Tensor:
+    """
+    Applies penalties in place to the logits tensor
+    logits : The input logits tensor of shape [num_seqs, vocab_size]
+    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
+        are padded to the maximum prompt length within the batch using
+        `vocab_size` as the padding value. The value `vocab_size` is used
+        for padding because it does not correspond to any valid token ID
+        in the vocabulary.
+    output_tokens_tensor: The output tokens tensor.
+    presence_penalties: The presence penalties of shape (num_seqs, )
+    frequency_penalties: The frequency penalties of shape (num_seqs, )
+    repetition_penalties: The repetition penalties of shape (num_seqs, )
+    """
+    num_seqs, vocab_size = logits.shape
+    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
+                                                   vocab_size, num_seqs)
+    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
+        output_tokens_tensor, vocab_size, num_seqs)
+    repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
+        1, vocab_size)
+    logits[logits > 0] /= torch.where(prompt_mask | output_mask,
+                                      repetition_penalties, 1.0)[logits > 0]
+    logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
+                                       repetition_penalties, 1.0)[logits <= 0]
+    # We follow the definition in OpenAI API.
+    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
+    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
+    return logits
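For reference, a minimal usage sketch of the new helper on a toy batch (all tensor values below are invented for illustration; only the function itself comes from the diff):

import torch

from vllm.model_executor.layers.utils import apply_penalties

vocab_size = 4
# Prompts/outputs are padded with vocab_size (here 4); that id lands in the
# extra bin that get_token_bin_counts_and_mask slices off, so padding never
# counts as a real token.
prompt_tokens = torch.tensor([[0, 1, 4], [2, 2, 3]])
output_tokens = torch.tensor([[1, 1, 4], [3, 4, 4]])
logits = torch.tensor([[1.0, 2.0, -1.0, 0.5],
                       [0.2, -0.3, 1.5, 0.0]])
presence = torch.tensor([0.5, 0.0])
frequency = torch.tensor([0.1, 0.0])
repetition = torch.tensor([1.2, 1.0])

logits = apply_penalties(logits, prompt_tokens, output_tokens,
                         presence, frequency, repetition)
# Sequence 0, token 1 (seen twice in the output):
# 2.0 / 1.2 - 0.1 * 2 - 0.5 * 1 ≈ 0.97, i.e. the token becomes less likely.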
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, List, Optional, Set
 
 import torch
 
@@ -19,3 +19,13 @@ class SamplingMetadata:
     generators: Dict[int, torch.Generator]
 
     max_num_logprobs: int
+
+    no_penalties: bool
+    prompt_token_ids: Optional[torch.Tensor]
+    frequency_penalties: torch.Tensor
+    presence_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+
+    output_token_ids: List[List[int]]
+    min_tokens: List[int]
+    stop_token_ids: List[Set[int]]
@@ -1,9 +1,11 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict
+from typing import Dict, List, Set, Tuple
 
 import torch
 import torch.nn as nn
 
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -17,9 +19,18 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        _apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
+                                   sampling_metadata.stop_token_ids,
+                                   sampling_metadata.min_tokens)
+        if not sampling_metadata.no_penalties:
+            assert sampling_metadata.prompt_token_ids is not None
+            _apply_penalties(logits, sampling_metadata.prompt_token_ids,
+                             sampling_metadata.presence_penalties,
+                             sampling_metadata.frequency_penalties,
+                             sampling_metadata.repetition_penalties,
+                             sampling_metadata.output_token_ids)
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
         logits = self.apply_top_k_top_p(logits, sampling_metadata)
 
         probs = self.get_probs(logits)
         sampled = self.sample(probs, sampling_metadata)
         # Use int32 to reduce the tensor size.
@@ -157,3 +168,53 @@ def _apply_top_k_top_p(
     # Re-sort the probabilities.
     logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
     return logits
+
+
+def _apply_min_token_penalties(logits: torch.Tensor,
+                               output_token_ids: List[List[int]],
+                               stop_token_ids: List[Set[int]],
+                               min_tokens: List[int]):
+    """
+    Applies minimum token penalty by setting the logits of the stop tokens
+    to -inf.
+    """
+    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
+    for index, min_token in enumerate(min_tokens):
+        if (len(output_token_ids[index]) < min_token):
+            for stop_token_id in stop_token_ids[index]:
+                min_tokens_logits_to_penalize.append((index, stop_token_id))
+    if min_tokens_logits_to_penalize:
+        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
+
+
+def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
+                     presence_penalties: torch.Tensor,
+                     frequency_penalties: torch.Tensor,
+                     repetition_penalties: torch.Tensor,
+                     output_token_ids: List[List[int]]):
+    """
+    Applies presence, frequency and repetition penalties to the logits.
+    """
+    _, vocab_size = logits.shape
+    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+                                          logits.device)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)
+
+
+def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
+                        device: torch.device) -> torch.Tensor:
+    """
+    Convert the different list data structures to tensors.
+    """
+    output_tokens_tensor = make_tensor_with_pad(
+        output_token_ids,
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        pad=vocab_size,
+        device="cpu",
+        dtype=torch.int64,
+        pin_memory=is_pin_memory_available(),
+    )
+    return output_tokens_tensor.to(device, non_blocking=True)
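To make the min-tokens masking concrete, here is a small self-contained sketch of the same indexing the private helper above performs, on made-up values:

import torch

logits = torch.zeros(2, 5)                 # two sequences, vocab of 5
output_token_ids = [[7], [7, 8, 9]]        # tokens generated so far
stop_token_ids = [{0, 4}, {1}]             # per-sequence stop token ids
min_tokens = [3, 2]                        # required tokens before stopping

# Sequence 0 has generated only 1 < 3 tokens, so its stop tokens 0 and 4 are
# forced to -inf; sequence 1 already has 3 >= 2 tokens and is left untouched.
to_penalize = []
for i, min_token in enumerate(min_tokens):
    if len(output_token_ids[i]) < min_token:
        to_penalize.extend((i, t) for t in stop_token_ids[i])
if to_penalize:
    logits[tuple(zip(*to_penalize))] = -float("inf")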
@@ -43,12 +43,14 @@ class InputBatch:
         max_num_blocks_per_req: int,
         device: torch.device,
         pin_memory: bool,
+        vocab_size: int,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.device = device
         self.pin_memory = pin_memory
+        self.vocab_size = vocab_size
 
         self.req_ids: List[Optional[str]] = [None] * max_num_reqs
         self.req_id_to_index: Dict[str, int] = {}
@@ -63,6 +65,7 @@ class InputBatch:
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
 
         # Attention-related.
         self.block_table = torch.zeros(
@@ -110,6 +113,50 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()
 
+        # Frequency penalty related data structures
+        self.frequency_penalties = torch.empty((max_num_reqs, ),
+                                               dtype=torch.float,
+                                               device=device)
+        self.frequency_penalties_cpu_tensor = torch.empty(
+            (max_num_reqs, ),
+            dtype=torch.float,
+            device="cpu",
+            pin_memory=pin_memory)
+        self.frequency_penalties_cpu = \
+            self.frequency_penalties_cpu_tensor.numpy()
+        self.frequency_penalties_reqs: Set[str] = set()
+
+        # Presence penalty related data structures
+        self.presence_penalties = torch.empty((max_num_reqs, ),
+                                              dtype=torch.float,
+                                              device=device)
+        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
+                                                          dtype=torch.float,
+                                                          device="cpu",
+                                                          pin_memory=pin_memory)
+        self.presence_penalties_cpu = \
+            self.presence_penalties_cpu_tensor.numpy()
+        self.presence_penalties_reqs: Set[str] = set()
+
+        # Repetition penalty related data structures
+        self.repetition_penalties = torch.empty((max_num_reqs, ),
+                                                dtype=torch.float,
+                                                device=device)
+        self.repetition_penalties_cpu_tensor = torch.empty(
+            (max_num_reqs, ),
+            dtype=torch.float,
+            device="cpu",
+            pin_memory=pin_memory)
+        self.repetition_penalties_cpu = \
+            self.repetition_penalties_cpu_tensor.numpy()
+        self.repetition_penalties_reqs: Set[str] = set()
+
+        self.min_tokens: List[int] = [0] * max_num_reqs
+        self.stop_token_ids: List[Set[int]] = [
+            set() for _ in range(max_num_reqs)
+        ]
+        self.prompt_token_ids: Optional[torch.Tensor] = None
+
         # req_index -> generator
         # NOTE(woosuk): The indices of the requests that do not have their own
         # generator should not be included in the dictionary.
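The three penalty fields all follow the same staging pattern; a minimal sketch of that pattern outside InputBatch (names and sizes are illustrative, and a CUDA device is assumed):

import torch

max_num_reqs = 8
# Pinned CPU buffer that per-request scalar writes go into, plus a GPU tensor
# of the same size that receives an asynchronous copy when sampling needs it.
penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                   dtype=torch.float,
                                   device="cpu",
                                   pin_memory=True)
penalties_cpu = penalties_cpu_tensor.numpy()   # cheap scalar writes via numpy
penalties_gpu = torch.empty((max_num_reqs, ), dtype=torch.float, device="cuda")

penalties_cpu[0] = 0.3   # e.g. a frequency penalty for request index 0
# Copy only the active slice, without blocking the host.
penalties_gpu[:1].copy_(penalties_cpu_tensor[:1], non_blocking=True)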
@@ -133,6 +180,7 @@ class InputBatch:
 
         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
+        self.num_prompt_tokens[req_index] = num_prompt_tokens
         self.token_ids_cpu[
             req_index, :num_prompt_tokens] = request.prompt_token_ids
         start_idx = num_prompt_tokens
@@ -157,6 +205,20 @@ class InputBatch:
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
             self.top_k_reqs.add(req_id)
+        self.frequency_penalties_cpu[req_index] = \
+            sampling_params.frequency_penalty
+        if sampling_params.frequency_penalty != 0.0:
+            self.frequency_penalties_reqs.add(req_id)
+        self.presence_penalties_cpu[req_index] = \
+            sampling_params.presence_penalty
+        if sampling_params.presence_penalty != 0.0:
+            self.presence_penalties_reqs.add(req_id)
+        self.repetition_penalties_cpu[req_index] = \
+            sampling_params.repetition_penalty
+        if sampling_params.repetition_penalty != 1.0:
+            self.repetition_penalties_reqs.add(req_id)
+        self.min_tokens[req_index] = sampling_params.min_tokens
+        self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
 
         # NOTE(woosuk): self.generators should not include the requests that
         # do not have their own generator.
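For context, a hedged example of a request whose SamplingParams would take these new branches (values are made up; any non-default penalty puts the request into the corresponding *_reqs set and keeps no_penalties False for the batch):

from vllm import SamplingParams

params = SamplingParams(
    frequency_penalty=0.3,    # != 0.0 -> tracked in frequency_penalties_reqs
    presence_penalty=0.2,     # != 0.0 -> tracked in presence_penalties_reqs
    repetition_penalty=1.1,   # != 1.0 -> tracked in repetition_penalties_reqs
    min_tokens=16,            # stop tokens suppressed for the first 16 tokens
)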
@@ -179,6 +241,9 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
+        self.frequency_penalties_reqs.discard(req_id)
+        self.presence_penalties_reqs.discard(req_id)
+        self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
@@ -191,6 +256,9 @@ class InputBatch:
         self.random_reqs.clear()
         self.top_p_reqs.clear()
         self.top_k_reqs.clear()
+        self.frequency_penalties_reqs.clear()
+        self.presence_penalties_reqs.clear()
+        self.repetition_penalties_reqs.clear()
         self.generators.clear()
         self.num_logprobs.clear()
         self.prompt_logprob_reqs.clear()
@@ -224,6 +292,8 @@ class InputBatch:
             # block_table_cpu.
             self.token_ids_cpu[empty_index] = self.token_ids_cpu[
                 last_req_index]
+            self.num_prompt_tokens[empty_index] = \
+                self.num_prompt_tokens[last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table_cpu[empty_index] = self.block_table_cpu[
@@ -232,6 +302,15 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
+            self.frequency_penalties_cpu[empty_index] = \
+                self.frequency_penalties_cpu[last_req_index]
+            self.presence_penalties_cpu[empty_index] = \
+                self.presence_penalties_cpu[last_req_index]
+            self.repetition_penalties_cpu[empty_index] = \
+                self.repetition_penalties_cpu[last_req_index]
+            self.min_tokens[empty_index] = self.min_tokens[last_req_index]
+            self.stop_token_ids[empty_index] = \
+                self.stop_token_ids[last_req_index]
             generator = self.generators.pop(last_req_index, None)
             if generator is not None:
                 self.generators[empty_index] = generator
@@ -241,6 +320,7 @@ class InputBatch:
 
     def make_sampling_metadata(
         self,
+        req_id_output_token_ids: Dict[str, List[int]],
         skip_copy: bool = False,
     ) -> SamplingMetadata:
         if not skip_copy:
@@ -250,6 +330,37 @@ class InputBatch:
                 self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
             self.top_k[:self.num_reqs].copy_(
                 self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
+            if not self.no_penalties:
+                # Since syncing these tensors is expensive only copy them
+                # if necessary i.e. if there are requests which require
+                # penalties to be applied during sampling.
+                self.frequency_penalties[:self.num_reqs].copy_(
+                    self.frequency_penalties_cpu_tensor[:self.num_reqs],
+                    non_blocking=True)
+                self.presence_penalties[:self.num_reqs].copy_(
+                    self.presence_penalties_cpu_tensor[:self.num_reqs],
+                    non_blocking=True)
+                self.repetition_penalties[:self.num_reqs].copy_(
+                    self.repetition_penalties_cpu_tensor[:self.num_reqs],
+                    non_blocking=True)
+                # The prompt tokens are used only for applying penalties during
+                # the sampling process. Hence copy these tensors only when
+                # there are requests which need penalties to be applied.
+                self.prompt_token_ids = self._make_prompt_token_ids_tensor()
+
+        output_token_ids: List[List[int]] = []
+
+        for req_id in self.req_ids[:self.num_reqs]:
+            assert req_id is not None
+            # Currently we create a tensor for output_token_ids from scratch
+            # at each step. However, for the penalties computation what we
+            # need is stats about the token ids present in the output. This
+            # stats can be maintained incrementally instead of computing it
+            # from scratch at each step.
+            # TODO - Replace this with incremental update to output token
+            # statistics.
+            output_token_ids.append(req_id_output_token_ids[req_id])
+
         return SamplingMetadata(
             temperature=self.temperature[:self.num_reqs],
             all_greedy=self.all_greedy,
@@ -260,8 +371,33 @@ class InputBatch:
             no_top_k=self.no_top_k,
             generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
+            prompt_token_ids=self.prompt_token_ids,
+            frequency_penalties=self.frequency_penalties[:self.num_reqs],
+            presence_penalties=self.presence_penalties[:self.num_reqs],
+            repetition_penalties=self.repetition_penalties[:self.num_reqs],
+            output_token_ids=output_token_ids,
+            min_tokens=self.min_tokens[:self.num_reqs],
+            stop_token_ids=self.stop_token_ids[:self.num_reqs],
+            no_penalties=self.no_penalties,
         )
 
+    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
+        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        prompt_token_ids_cpu_tensor = torch.empty(
+            (self.num_reqs, max_prompt_len),
+            device="cpu",
+            dtype=torch.int64,
+            pin_memory=self.pin_memory)
+        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
+        prompt_token_ids[:] = (
+            self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        for i in range(self.num_reqs):
+            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
+        return prompt_token_ids_cpu_tensor.to(device=self.device,
+                                               non_blocking=True)
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
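A toy illustration of the padding _make_prompt_token_ids_tensor produces (values made up): with vocab_size = 100 and prompt lengths [3, 1], the shorter row is right-padded with 100, an id outside the real vocabulary, so padded positions never contribute to the penalty masks:

import torch

vocab_size = 100
token_ids_cpu = torch.tensor([[11, 12, 13],
                              [42,  7,  9]])   # row 1 has stale ids past its prompt
num_prompt_tokens = [3, 1]

prompt_token_ids = token_ids_cpu.clone()
for i, n in enumerate(num_prompt_tokens):
    prompt_token_ids[i, n:] = vocab_size
# prompt_token_ids -> [[11, 12, 13], [42, 100, 100]]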
@@ -282,6 +418,12 @@ class InputBatch:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0
 
+    @property
+    def no_penalties(self) -> bool:
+        return (len(self.presence_penalties_reqs) == 0
+                and len(self.frequency_penalties_reqs) == 0
+                and len(self.repetition_penalties_reqs) == 0)
+
     @property
     def max_num_logprobs(self) -> int:
         return max(self.num_logprobs.values()) if self.num_logprobs else 0
 
@@ -105,6 +105,7 @@ class GPUModelRunner:
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
+            vocab_size=model_config.get_vocab_size(),
         )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -383,7 +384,12 @@ class GPUModelRunner:
                 or scheduler_output.scheduled_resumed_reqs):
             skip_copy = False
         # Create the sampling metadata.
-        sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
+        req_id_output_token_ids: Dict[str, List[int]] = \
+            {req_id: req.output_token_ids \
+                for req_id, req in self.requests.items()}
+
+        sampling_metadata = self.input_batch.make_sampling_metadata(
+            req_id_output_token_ids, skip_copy)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):