[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)
This commit is contained in:
@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
|
||||
def _raise_if_incorrect_input(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self._raise_if_incorrect_shape(target_probs, draft_token_ids,
|
||||
bonus_token_ids, draft_probs)
|
||||
self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
|
||||
bonus_token_ids, draft_probs)
|
||||
self._raise_if_inconsistent_device(target_probs, draft_token_ids,
|
||||
bonus_token_ids, draft_probs)
|
||||
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
|
||||
self._raise_if_incorrect_shape(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
self._raise_if_incorrect_dtype(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
self._raise_if_inconsistent_device(target_with_bonus_probs,
|
||||
draft_token_ids, bonus_token_ids,
|
||||
draft_probs)
|
||||
self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
|
||||
draft_token_ids, bonus_token_ids)
|
||||
|
||||
def _raise_if_incorrect_shape(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
(target_batch_size, num_target_probs,
|
||||
target_vocab_size) = target_probs.shape
|
||||
target_vocab_size) = target_with_bonus_probs.shape
|
||||
|
||||
# Does not count the extra token
|
||||
num_target_probs -= 1
|
||||
|
||||
# validate the shape of draft token ids.
|
||||
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
|
||||
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
|
||||
def _raise_if_incorrect_dtype(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
assert target_probs.dtype == self.probs_dtype
|
||||
assert target_with_bonus_probs.dtype == self.probs_dtype
|
||||
assert draft_token_ids.dtype == self.token_id_dtype
|
||||
assert bonus_token_ids.dtype == self.token_id_dtype
|
||||
if draft_probs is not None:
|
||||
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
|
||||
def _raise_if_inconsistent_device(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
devices = [
|
||||
t.device for t in
|
||||
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
|
||||
if t is not None
|
||||
t.device for t in [
|
||||
target_with_bonus_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids
|
||||
] if t is not None
|
||||
]
|
||||
assert all([devices[0] == device for device in devices])
|
||||
|
||||
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
target_with_bonus_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user