[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)

This commit is contained in:
Lily Liu
2024-09-01 21:23:29 -07:00
committed by GitHub
parent f8d60145b4
commit e6a26ed037
9 changed files with 306 additions and 109 deletions

View File

@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_incorrect_input(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
self._raise_if_incorrect_shape(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_inconsistent_device(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
self._raise_if_incorrect_shape(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_incorrect_dtype(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_inconsistent_device(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)
self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
draft_token_ids, bonus_token_ids)
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
target_vocab_size) = target_with_bonus_probs.shape
# Does not count the extra token
num_target_probs -= 1
# validate the shape of draft token ids.
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
assert target_probs.dtype == self.probs_dtype
assert target_with_bonus_probs.dtype == self.probs_dtype
assert draft_token_ids.dtype == self.token_id_dtype
assert bonus_token_ids.dtype == self.token_id_dtype
if draft_probs is not None:
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor] = None,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
if t is not None
t.device for t in [
target_with_bonus_probs, bonus_token_ids, draft_probs,
draft_token_ids
] if t is not None
]
assert all([devices[0] == device for device in devices])
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
target_with_bonus_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,