Support logit_bias in v1 Sampler (#13079)

Lu Fang authored 2025-02-14 04:34:59 -08:00; committed by GitHub
parent 085b7b2d6c
commit 6224a9f620
6 changed files with 200 additions and 101 deletions
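
The commit threads a per-request logit_bias mapping from SamplingParams through the v1 InputBatch and into SamplingMetadata, so the v1 Sampler can add the bias to the logits before sampling; only the InputBatch side of the diff is reproduced below. As a usage sketch, the user-facing knob is SamplingParams.logit_bias, a token-id-to-bias dict in the OpenAI style. The model name and bias values here are illustrative, and passing logit_bias as a constructor argument is assumed from its use as an attribute in this diff:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # any model; illustrative
    params = SamplingParams(
        max_tokens=16,
        # Additive per-token-id bias: strongly negative values effectively
        # ban a token, positive values promote it.
        logit_bias={50256: -100.0, 1234: 5.0},
    )
    out = llm.generate(["The capital of France is"], params)
    print(out[0].outputs[0].text)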

vllm/v1/worker/gpu_input_batch.py

@@ -130,7 +130,7 @@ class InputBatch:
             device="cpu",
             pin_memory=pin_memory)
         self.frequency_penalties_cpu = \
-                self.frequency_penalties_cpu_tensor.numpy()
+            self.frequency_penalties_cpu_tensor.numpy()
         self.frequency_penalties_reqs: Set[str] = set()
         # Presence penalty related data structures
@@ -141,8 +141,8 @@ class InputBatch:
             dtype=torch.float,
             device="cpu",
             pin_memory=pin_memory)
-        self.presence_penalties_cpu = \
-            self.presence_penalties_cpu_tensor.numpy()
+        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
+        )
         self.presence_penalties_reqs: Set[str] = set()
         # Repetition penalty related data structures
@@ -155,7 +155,7 @@ class InputBatch:
             device="cpu",
             pin_memory=pin_memory)
         self.repetition_penalties_cpu = \
-                self.repetition_penalties_cpu_tensor.numpy()
+            self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: Set[str] = set()
         self.min_tokens: List[int] = [0] * max_num_reqs
@@ -180,6 +180,9 @@ class InputBatch:
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: Dict[str, int] = {}
 
+        self.logit_bias: List[Optional[Dict[int,
+                                            float]]] = [None] * max_num_reqs
+
     def add_request(
         self,
         request: "CachedRequestState",
@@ -220,16 +223,16 @@ class InputBatch:
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
             self.top_k_reqs.add(req_id)
-        self.frequency_penalties_cpu[req_index] = \
-            sampling_params.frequency_penalty
+        self.frequency_penalties_cpu[
+            req_index] = sampling_params.frequency_penalty
         if sampling_params.frequency_penalty != 0.0:
             self.frequency_penalties_reqs.add(req_id)
-        self.presence_penalties_cpu[req_index] = \
-            sampling_params.presence_penalty
+        self.presence_penalties_cpu[
+            req_index] = sampling_params.presence_penalty
         if sampling_params.presence_penalty != 0.0:
             self.presence_penalties_reqs.add(req_id)
-        self.repetition_penalties_cpu[req_index] = \
-            sampling_params.repetition_penalty
+        self.repetition_penalties_cpu[
+            req_index] = sampling_params.repetition_penalty
         if sampling_params.repetition_penalty != 1.0:
             self.repetition_penalties_reqs.add(req_id)
         self.min_tokens[req_index] = sampling_params.min_tokens
@@ -244,6 +247,8 @@ class InputBatch:
             self.num_logprobs[req_id] = sampling_params.logprobs
         if sampling_params.prompt_logprobs is not None:
             self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
+        if sampling_params.logit_bias is not None:
+            self.logit_bias[req_index] = sampling_params.logit_bias
 
         # Add request lora ID
         if request.lora_request:
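
Note the asymmetry with the penalties above: each penalty is a dense per-request scalar staged in a pinned CPU tensor and bulk-copied to the GPU, while logit_bias is ragged (a different set of token ids per request), so it stays a plain Python list of dicts. A hypothetical side-by-side of the two storage patterns:

    import torch

    frequency_penalties_cpu = torch.zeros(4).numpy()  # one float per slot
    logit_bias = [None, None, None, None]             # one dict or None per slot

    # add_request wiring for the request in slot 1:
    frequency_penalties_cpu[1] = 0.5
    logit_bias[1] = {50256: -100.0}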
@@ -284,6 +289,7 @@ class InputBatch:
             self.lora_id_to_lora_request.pop(lora_id)
         self.request_lora_mapping[req_index] = 0
+        self.logit_bias[req_index] = None
         return req_index
 
     def clear(self) -> None:
@@ -302,6 +308,7 @@ class InputBatch:
         self.request_lora_mapping.fill(0)
         self.lora_id_to_lora_request.clear()
         self.lora_id_to_request_ids.clear()
+        self.logit_bias = [None] * self.max_num_reqs
 
     def condense(self, empty_req_indices: List[int]) -> None:
         if self.num_reqs == 0:
@@ -332,8 +339,8 @@ class InputBatch:
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
             self.num_tokens[empty_index] = num_tokens
-            self.num_prompt_tokens[empty_index] = \
-                self.num_prompt_tokens[last_req_index]
+            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
+                last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
@@ -341,15 +348,15 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.frequency_penalties_cpu[empty_index] = \
-                self.frequency_penalties_cpu[last_req_index]
-            self.presence_penalties_cpu[empty_index] = \
-                self.presence_penalties_cpu[last_req_index]
-            self.repetition_penalties_cpu[empty_index] = \
-                self.repetition_penalties_cpu[last_req_index]
+            self.frequency_penalties_cpu[
+                empty_index] = self.frequency_penalties_cpu[last_req_index]
+            self.presence_penalties_cpu[
+                empty_index] = self.presence_penalties_cpu[last_req_index]
+            self.repetition_penalties_cpu[
+                empty_index] = self.repetition_penalties_cpu[last_req_index]
             self.min_tokens[empty_index] = self.min_tokens[last_req_index]
-            self.stop_token_ids[empty_index] = \
-                self.stop_token_ids[last_req_index]
+            self.stop_token_ids[empty_index] = self.stop_token_ids[
+                last_req_index]
             generator = self.generators.pop(last_req_index, None)
             if generator is not None:
                 self.generators[empty_index] = generator
@@ -357,6 +364,8 @@ class InputBatch:
             self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                 last_req_index]
 
+            self.logit_bias[empty_index] = self.logit_bias[last_req_index]
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
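
condense() keeps the batch dense: when a slot frees up, the request in the last occupied slot is moved into it, and the new field must move along or a stale bias would leak into whichever request lands in that slot next. A toy trace with hypothetical values:

    logit_bias = [{11: 2.0}, None, {22: -3.0}]  # slot 1 has just been freed
    empty_index, last_req_index = 1, 2
    logit_bias[empty_index] = logit_bias[last_req_index]
    # Slot 2's bias now lives in slot 1; once num_reqs is decremented,
    # slot 2 is treated as vacant and its leftover entry is never read.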
@@ -378,13 +387,16 @@ class InputBatch:
             # penalties to be applied during sampling.
             self.frequency_penalties[:self.num_reqs].copy_(
                 self.frequency_penalties_cpu_tensor[:self.num_reqs],
-                non_blocking=True)
+                non_blocking=True,
+            )
             self.presence_penalties[:self.num_reqs].copy_(
                 self.presence_penalties_cpu_tensor[:self.num_reqs],
-                non_blocking=True)
+                non_blocking=True,
+            )
             self.repetition_penalties[:self.num_reqs].copy_(
                 self.repetition_penalties_cpu_tensor[:self.num_reqs],
-                non_blocking=True)
+                non_blocking=True,
+            )
             # The prompt tokens are used only for applying penalties during
             # the sampling process. Hence copy these tensors only when
             # there are requests which need penalties to be applied.
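
The non_blocking=True copies above can only overlap with other work because the source tensors were allocated with pin_memory; from pageable memory the transfer would degrade to a synchronous copy. A minimal standalone sketch of the same staging pattern (requires a CUDA device):

    import torch

    src = torch.empty(8, dtype=torch.float, pin_memory=True)  # CPU staging buffer
    dst = torch.empty(8, dtype=torch.float, device="cuda")

    src[:4] = 0.5                              # fill the active prefix on the CPU
    dst[:4].copy_(src[:4], non_blocking=True)  # async host-to-device copy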
@@ -421,6 +433,7 @@ class InputBatch:
             min_tokens=self.min_tokens[:self.num_reqs],
             stop_token_ids=self.stop_token_ids[:self.num_reqs],
             no_penalties=self.no_penalties,
+            logit_bias=self.logit_bias[:self.num_reqs],
         )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
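
With the slice threaded into SamplingMetadata, the sampler side of the commit (in a different file, not shown here) can fold the bias into the logits before sampling. A minimal sketch of one straightforward way to consume the structure, not necessarily the exact loop the v1 Sampler uses:

    import torch
    from typing import Dict, List, Optional

    def apply_logit_bias(
            logits: torch.Tensor,  # [num_reqs, vocab_size]
            logit_bias: List[Optional[Dict[int, float]]]) -> torch.Tensor:
        # Add each request's {token_id: bias} entries to its row of logits.
        for row, bias in enumerate(logit_bias):
            if bias:
                for token_id, value in bias.items():
                    logits[row, token_id] += value
        return logits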
@@ -429,10 +442,11 @@ class InputBatch:
             (self.num_reqs, max_prompt_len),
             device="cpu",
             dtype=torch.int64,
-            pin_memory=self.pin_memory)
+            pin_memory=self.pin_memory,
+        )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = (
-            self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
+        prompt_token_ids[:] = self.token_ids_cpu[:self.
+                                                 num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
         for i in range(self.num_reqs):
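
The padding trick in this last hunk works because vocab_size is one larger than any real token id, so a pad can never collide with a real token; downstream penalty code can, for example, bin-count token occurrences into vocab_size + 1 columns and drop the last one instead of masking. A small illustration:

    import torch

    vocab_size = 8
    # Two prompts of lengths 3 and 5, right-padded with vocab_size itself.
    prompt_token_ids = torch.tensor([[1, 2, 3, vocab_size, vocab_size],
                                     [4, 5, 6, 7, 0]])
    counts = torch.nn.functional.one_hot(
        prompt_token_ids, num_classes=vocab_size + 1).sum(dim=1)[:, :-1]
    # counts[0] ignores the two pad positions without any explicit mask.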