[V1][Core] min_p sampling support (#13191)
Signed-off-by: Aoyu <aoyuzhan@amazon.com> Co-authored-by: Aoyu <aoyuzhan@amazon.com>
This commit is contained in:
@@ -14,6 +14,8 @@ from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
|
||||
@@ -120,6 +122,16 @@ class InputBatch:
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: Set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: Set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
@@ -223,8 +235,11 @@ class InputBatch:
|
||||
self.top_k_cpu[req_index] = sampling_params.top_k
|
||||
if sampling_params.top_k > 0:
|
||||
self.top_k_reqs.add(req_id)
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
@@ -273,6 +288,7 @@ class InputBatch:
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
@@ -299,6 +315,7 @@ class InputBatch:
|
||||
self.random_reqs.clear()
|
||||
self.top_p_reqs.clear()
|
||||
self.top_k_reqs.clear()
|
||||
self.min_p_reqs.clear()
|
||||
self.frequency_penalties_reqs.clear()
|
||||
self.presence_penalties_reqs.clear()
|
||||
self.repetition_penalties_reqs.clear()
|
||||
@@ -354,6 +371,7 @@ class InputBatch:
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
|
||||
self.stop_token_ids[empty_index] = self.stop_token_ids[
|
||||
last_req_index]
|
||||
@@ -381,6 +399,8 @@ class InputBatch:
|
||||
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
self.top_k[:self.num_reqs].copy_(
|
||||
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
self.min_p[:self.num_reqs].copy_(
|
||||
self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
# if necessary i.e. if there are requests which require
|
||||
@@ -421,6 +441,8 @@ class InputBatch:
|
||||
all_random=self.all_random,
|
||||
top_p=self.top_p[:self.num_reqs],
|
||||
top_k=self.top_k[:self.num_reqs],
|
||||
min_p=self.min_p[:self.num_reqs],
|
||||
no_min_p=self.no_min_p,
|
||||
no_top_p=self.no_top_p,
|
||||
no_top_k=self.no_top_k,
|
||||
generators=self.generators,
|
||||
@@ -497,6 +519,10 @@ class InputBatch:
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
# Reports whether the min_p request set is empty (i.e. no tracked request
# needs min_p filtering). Elsewhere in this diff a request id is added to
# `min_p_reqs` only when its `sampling_params.min_p` exceeds `_SAMPLING_EPS`,
# and it is discarded/cleared together with the other per-request sets.
@property
|
||||
def no_min_p(self) -> bool:
|
||||
return len(self.min_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (len(self.presence_penalties_reqs) == 0
|
||||
|
||||
Reference in New Issue
Block a user