Support logit_bias in v1 Sampler (#13079)
This commit is contained in:
@@ -130,7 +130,7 @@ class InputBatch:
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.frequency_penalties_cpu = \
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: Set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
@@ -141,8 +141,8 @@ class InputBatch:
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.presence_penalties_cpu = \
|
||||
self.presence_penalties_cpu_tensor.numpy()
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
||||
)
|
||||
self.presence_penalties_reqs: Set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
@@ -155,7 +155,7 @@ class InputBatch:
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.repetition_penalties_cpu = \
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: Set[str] = set()
|
||||
|
||||
self.min_tokens: List[int] = [0] * max_num_reqs
|
||||
@@ -180,6 +180,9 @@ class InputBatch:
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: Dict[str, int] = {}
|
||||
|
||||
self.logit_bias: List[Optional[Dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
@@ -220,16 +223,16 @@ class InputBatch:
|
||||
self.top_k_cpu[req_index] = sampling_params.top_k
|
||||
if sampling_params.top_k > 0:
|
||||
self.top_k_reqs.add(req_id)
|
||||
self.frequency_penalties_cpu[req_index] = \
|
||||
sampling_params.frequency_penalty
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[req_index] = \
|
||||
sampling_params.presence_penalty
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[req_index] = \
|
||||
sampling_params.repetition_penalty
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
self.min_tokens[req_index] = sampling_params.min_tokens
|
||||
@@ -244,6 +247,8 @@ class InputBatch:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
@@ -284,6 +289,7 @@ class InputBatch:
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
return req_index
|
||||
|
||||
def clear(self) -> None:
|
||||
@@ -302,6 +308,7 @@ class InputBatch:
|
||||
self.request_lora_mapping.fill(0)
|
||||
self.lora_id_to_lora_request.clear()
|
||||
self.lora_id_to_request_ids.clear()
|
||||
self.logit_bias = [None] * self.max_num_reqs
|
||||
|
||||
def condense(self, empty_req_indices: List[int]) -> None:
|
||||
if self.num_reqs == 0:
|
||||
@@ -332,8 +339,8 @@ class InputBatch:
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
self.num_tokens[empty_index] = num_tokens
|
||||
self.num_prompt_tokens[empty_index] = \
|
||||
self.num_prompt_tokens[last_req_index]
|
||||
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
||||
last_req_index]
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
@@ -341,15 +348,15 @@ class InputBatch:
|
||||
last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
||||
self.frequency_penalties_cpu[empty_index] = \
|
||||
self.frequency_penalties_cpu[last_req_index]
|
||||
self.presence_penalties_cpu[empty_index] = \
|
||||
self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[empty_index] = \
|
||||
self.repetition_penalties_cpu[last_req_index]
|
||||
self.frequency_penalties_cpu[
|
||||
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
||||
self.presence_penalties_cpu[
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
|
||||
self.stop_token_ids[empty_index] = \
|
||||
self.stop_token_ids[last_req_index]
|
||||
self.stop_token_ids[empty_index] = self.stop_token_ids[
|
||||
last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
@@ -357,6 +364,8 @@ class InputBatch:
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
@@ -378,13 +387,16 @@ class InputBatch:
|
||||
# penalties to be applied during sampling.
|
||||
self.frequency_penalties[:self.num_reqs].copy_(
|
||||
self.frequency_penalties_cpu_tensor[:self.num_reqs],
|
||||
non_blocking=True)
|
||||
non_blocking=True,
|
||||
)
|
||||
self.presence_penalties[:self.num_reqs].copy_(
|
||||
self.presence_penalties_cpu_tensor[:self.num_reqs],
|
||||
non_blocking=True)
|
||||
non_blocking=True,
|
||||
)
|
||||
self.repetition_penalties[:self.num_reqs].copy_(
|
||||
self.repetition_penalties_cpu_tensor[:self.num_reqs],
|
||||
non_blocking=True)
|
||||
non_blocking=True,
|
||||
)
|
||||
# The prompt tokens are used only for applying penalties during
|
||||
# the sampling process. Hence copy these tensors only when
|
||||
# there are requests which need penalties to be applied.
|
||||
@@ -421,6 +433,7 @@ class InputBatch:
|
||||
min_tokens=self.min_tokens[:self.num_reqs],
|
||||
stop_token_ids=self.stop_token_ids[:self.num_reqs],
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:self.num_reqs],
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
@@ -429,10 +442,11 @@ class InputBatch:
|
||||
(self.num_reqs, max_prompt_len),
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory)
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
||||
prompt_token_ids[:] = (
|
||||
self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
|
||||
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
||||
num_reqs, :max_prompt_len]
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
for i in range(self.num_reqs):
|
||||
|
||||
Reference in New Issue
Block a user