[Model Runner V2] Minor refactor for logit_bias (#32209)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -119,35 +119,18 @@ class LogitBiasState:
|
||||
idx_mapping: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(
|
||||
max(
|
||||
MAX_NUM_ALLOWED_TOKEN_IDS,
|
||||
MAX_NUM_LOGIT_BIAS_TOKENS,
|
||||
MAX_NUM_STOP_TOKEN_IDS,
|
||||
)
|
||||
)
|
||||
LOGITS_BLOCK_SIZE = 8192
|
||||
_bias_kernel[(num_reqs,)](
|
||||
apply_logit_bias(
|
||||
logits,
|
||||
logits.stride(0),
|
||||
vocab_size,
|
||||
idx_mapping,
|
||||
self.num_allowed_token_ids,
|
||||
self.allowed_token_ids,
|
||||
self.allowed_token_ids.gpu.stride(0),
|
||||
self.num_logit_bias,
|
||||
self.logit_bias_token_ids,
|
||||
self.logit_bias_token_ids.gpu.stride(0),
|
||||
self.logit_bias,
|
||||
self.logit_bias.gpu.stride(0),
|
||||
pos,
|
||||
self.min_lens,
|
||||
self.num_stop_token_ids,
|
||||
self.stop_token_ids,
|
||||
self.stop_token_ids.gpu.stride(0),
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
|
||||
self.num_allowed_token_ids.gpu,
|
||||
self.allowed_token_ids.gpu,
|
||||
self.num_logit_bias.gpu,
|
||||
self.logit_bias_token_ids.gpu,
|
||||
self.logit_bias.gpu,
|
||||
self.min_lens.gpu,
|
||||
self.num_stop_token_ids.gpu,
|
||||
self.stop_token_ids.gpu,
|
||||
)
|
||||
|
||||
|
||||
@@ -240,3 +223,48 @@ def _bias_kernel(
|
||||
-float("inf"),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def apply_logit_bias(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
num_allowed_token_ids: torch.Tensor,
|
||||
allowed_token_ids: torch.Tensor,
|
||||
num_logit_bias: torch.Tensor,
|
||||
logit_bias_token_ids: torch.Tensor,
|
||||
logit_bias: torch.Tensor,
|
||||
min_lens: torch.Tensor,
|
||||
num_stop_token_ids: torch.Tensor,
|
||||
stop_token_ids: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(
|
||||
max(
|
||||
allowed_token_ids.shape[-1],
|
||||
logit_bias_token_ids.shape[-1],
|
||||
stop_token_ids.shape[-1],
|
||||
)
|
||||
)
|
||||
LOGITS_BLOCK_SIZE = 8192
|
||||
_bias_kernel[(num_reqs,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
vocab_size,
|
||||
idx_mapping,
|
||||
num_allowed_token_ids,
|
||||
allowed_token_ids,
|
||||
allowed_token_ids.stride(0),
|
||||
num_logit_bias,
|
||||
logit_bias_token_ids,
|
||||
logit_bias_token_ids.stride(0),
|
||||
logit_bias,
|
||||
logit_bias.stride(0),
|
||||
pos,
|
||||
min_lens,
|
||||
num_stop_token_ids,
|
||||
stop_token_ids,
|
||||
stop_token_ids.stride(0),
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user