[Bugfix] Make spec. decode respect per-request seed. (#6034)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Thomas Parnell
2024-07-19 04:22:08 +02:00
committed by GitHub
parent b5672a112c
commit d4201e06d5
8 changed files with 293 additions and 46 deletions

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional
from typing import List, Optional
import torch
import torch.jit
@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
def token_id_dtype(self):
return torch.int64
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
"""Base class for samplers used for Speculative Decoding verification
step which are deterministic.
"""
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
"""Base class for samplers used for Speculative Decoding verification
step which are stochastic
"""
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
) -> torch.Tensor:
raise NotImplementedError