[Model Runner V2] Minor refactor for logit_bias (#32209)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-12 13:08:30 -08:00
committed by GitHub
parent 9f430c94bd
commit dec28688c5

View File

@@ -119,35 +119,18 @@ class LogitBiasState:
idx_mapping: torch.Tensor,
pos: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
MAX_NUM_ALLOWED_TOKEN_IDS,
MAX_NUM_LOGIT_BIAS_TOKENS,
MAX_NUM_STOP_TOKEN_IDS,
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
apply_logit_bias(
logits,
logits.stride(0),
vocab_size,
idx_mapping,
self.num_allowed_token_ids,
self.allowed_token_ids,
self.allowed_token_ids.gpu.stride(0),
self.num_logit_bias,
self.logit_bias_token_ids,
self.logit_bias_token_ids.gpu.stride(0),
self.logit_bias,
self.logit_bias.gpu.stride(0),
pos,
self.min_lens,
self.num_stop_token_ids,
self.stop_token_ids,
self.stop_token_ids.gpu.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
self.num_logit_bias.gpu,
self.logit_bias_token_ids.gpu,
self.logit_bias.gpu,
self.min_lens.gpu,
self.num_stop_token_ids.gpu,
self.stop_token_ids.gpu,
)
@@ -240,3 +223,48 @@ def _bias_kernel(
-float("inf"),
mask=mask,
)
def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
num_logit_bias: torch.Tensor,
logit_bias_token_ids: torch.Tensor,
logit_bias: torch.Tensor,
min_lens: torch.Tensor,
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
logit_bias_token_ids.shape[-1],
stop_token_ids.shape[-1],
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),
num_logit_bias,
logit_bias_token_ids,
logit_bias_token_ids.stride(0),
logit_bias,
logit_bias.stride(0),
pos,
min_lens,
num_stop_token_ids,
stop_token_ids,
stop_token_ids.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
)