[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SpecDecodeBaseSampler():
|
||||
class SpecDecodeBaseSampler(nn.Module):
|
||||
"""Base class for samplers used for Speculative Decoding verification
|
||||
step.
|
||||
"""
|
||||
@@ -51,6 +54,16 @@ class SpecDecodeBaseSampler():
|
||||
def token_id_dtype(self):
|
||||
return torch.int64
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
target_probs: torch.Tensor,
|
||||
bonus_token_ids: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
draft_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _create_output(
|
||||
self,
|
||||
accepted: torch.Tensor, # [batch_size, k]
|
||||
|
||||
Reference in New Issue
Block a user