[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)

This commit is contained in:
sroy745
2024-07-01 00:33:05 -07:00
committed by GitHub
parent 614aa51203
commit 80ca1e6a3a
14 changed files with 480 additions and 208 deletions

View File

@@ -1,9 +1,12 @@
from abc import abstractmethod
from typing import Optional
import torch
import torch.jit
import torch.nn as nn
class SpecDecodeBaseSampler():
class SpecDecodeBaseSampler(nn.Module):
"""Base class for samplers used for Speculative Decoding verification
step.
"""
@@ -51,6 +54,16 @@ class SpecDecodeBaseSampler():
def token_id_dtype(self):
return torch.int64
# Abstract verification hook: the base class provides no implementation and
# raises NotImplementedError, so each concrete sampler must override this.
#
# Parameters (all torch.Tensor); meanings below are inferred from the names
# only -- TODO confirm against the concrete sampler implementations:
#   target_probs    -- presumably probabilities from the target model.
#   bonus_token_ids -- presumably the extra ("bonus") token ids.
#   draft_probs     -- presumably probabilities from the draft model.
#   draft_token_ids -- presumably the token ids proposed by the draft model.
# Returns a torch.Tensor; shape and dtype are not visible in this hunk.
#
# NOTE(review): the visible class header derives from nn.Module only. Unless
# an ABCMeta metaclass is applied elsewhere, @abstractmethod will not block
# instantiation, so the raise below is the effective guard -- confirm.
@abstractmethod
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]