[Misc] Misc code simplifications (#26450)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-10-09 02:10:06 -07:00
committed by GitHub
parent a83ff278d6
commit ddcbc2f334
6 changed files with 78 additions and 89 deletions

View File

@@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
) -> torch.Tensor:
has_penalties = not sampling_metadata.no_penalties
any_penalties_or_bad_words = (
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
sampling_metadata.bad_words_token_ids or has_penalties
)
output_token_ids = sampling_metadata.output_token_ids
if any_penalties_or_bad_words:
output_token_ids = self._combine_outputs_with_spec_tokens(
sampling_metadata.output_token_ids,
output_token_ids,
sampling_metadata.spec_token_ids,
)
# Calculate indices of target logits.
if (
sampling_metadata.allowed_token_ids_mask is not None
or not sampling_metadata.no_penalties
):
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
num_requests = len(sampling_metadata.output_token_ids)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu")
@@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
logits.masked_fill_(token_mask, float("-inf"))
# Apply bad words exclusion.
if sampling_metadata.bad_words_token_ids:
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
apply_bad_words_with_drafts(
logits,
sampling_metadata.bad_words_token_ids,
output_token_ids,
metadata.num_draft_tokens,
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
)
return logits
@staticmethod
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
@@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
)
return logits
@staticmethod
def _combine_outputs_with_spec_tokens(
self,
output_token_ids: list[list[int]],
spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]: