[Bugfix][V1] Fix allowed_token_ids for v1 Sampler (#14169)

Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
Lu Fang
2025-03-05 00:49:44 -08:00
committed by GitHub
parent ec79b67c77
commit 8d6cd32b7b
2 changed files with 12 additions and 4 deletions

View File

@@ -199,6 +199,8 @@ class InputBatch:
self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
@@ -300,6 +302,7 @@ class InputBatch:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
@@ -309,8 +312,10 @@ class InputBatch:
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = True
sampling_params.allowed_token_ids] = False
# Add request lora ID
if request.lora_request:
@@ -359,6 +364,7 @@ class InputBatch:
self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
return req_index