[Bugfix] Make spec. decode respect per-request seed. (#6034)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
def token_id_dtype(self):
|
||||
return torch.int64
|
||||
|
||||
@abstractmethod
def forward(
    self,
    target_probs: torch.Tensor,
    bonus_token_ids: torch.Tensor,
    draft_probs: torch.Tensor,
    draft_token_ids: torch.Tensor,
) -> torch.Tensor:
    """Abstract verification entry point; concrete samplers override it.

    This base implementation performs no work and always raises.

    Args:
        target_probs: Tensor of probabilities attributed to the target
            model (shape not visible here -- TODO confirm with callers).
        bonus_token_ids: Tensor of bonus token ids.
        draft_probs: Tensor of probabilities attributed to the draft
            model (shape not visible here -- TODO confirm with callers).
        draft_token_ids: Tensor of token ids proposed by the draft model.

    Returns:
        A tensor produced by the overriding implementation.

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    raise NotImplementedError
|
||||
|
||||
def _create_output(
|
||||
self,
|
||||
accepted: torch.Tensor, # [batch_size, k]
|
||||
@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
assert torch.all(bonus_token_ids >= 0)
|
||||
assert torch.all(draft_token_ids < vocab_size)
|
||||
assert torch.all(draft_token_ids >= 0)
|
||||
|
||||
|
||||
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Base class for deterministic samplers used in the Speculative
    Decoding verification step.

    Defines the ``forward`` signature for samplers that take no source of
    randomness; subclasses supply the actual implementation.
    """

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Abstract verification entry point for deterministic samplers.

        Args:
            target_probs: Tensor of probabilities attributed to the target
                model (shape not visible here -- TODO confirm).
            bonus_token_ids: Tensor of bonus token ids.
            draft_probs: Tensor of probabilities attributed to the draft
                model (shape not visible here -- TODO confirm).
            draft_token_ids: Tensor of token ids proposed by the draft
                model.

        Returns:
            A tensor produced by the overriding implementation.

        Raises:
            NotImplementedError: Always, when not overridden.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Base class for stochastic samplers used in the Speculative
    Decoding verification step.

    Unlike the deterministic variant, ``forward`` here also receives a
    list of optional random generators.
    """

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        generators: List[Optional[torch.Generator]],
    ) -> torch.Tensor:
        """Abstract verification entry point for stochastic samplers.

        Args:
            target_probs: Tensor of probabilities attributed to the target
                model (shape not visible here -- TODO confirm).
            bonus_token_ids: Tensor of bonus token ids.
            draft_probs: Tensor of probabilities attributed to the draft
                model (shape not visible here -- TODO confirm).
            draft_token_ids: Tensor of token ids proposed by the draft
                model.
            generators: List whose entries are either a
                ``torch.Generator`` or ``None``.
                NOTE(review): presumably one entry per request, with
                ``None`` meaning "no per-request seed" -- confirm against
                callers.

        Returns:
            A tensor produced by the overriding implementation.

        Raises:
            NotImplementedError: Always, when not overridden.
        """
        raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user