Remove hard dependencies of speculative decoding on CUDA workers (#10587)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
Chendi.Xue
2024-11-26 19:57:11 -06:00
committed by GitHub
parent 2f0a0a17a4
commit 0a71900bc9
19 changed files with 219 additions and 77 deletions

View File

@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
dtype=torch.long,
device=device)
def init_tensors(self,
                 device: Union[int, str],
                 device_type: Union[torch.device, str] = 'cuda') -> None:
    """Lazily allocate the acceptance-statistics counters on a device.

    Args:
        device: Either a full device string (e.g. ``"cuda:0"``) or a bare
            device index, which is combined with ``device_type``.
        device_type: Device kind used when ``device`` is an index; accepts
            either a string or a ``torch.device`` (only its ``.type`` is
            used). Defaults to ``'cuda'``.

    Must be called at most once: the counters must still be ``None``.
    """
    assert self.num_accepted_tokens is None
    # Reduce a torch.device object to its kind string (e.g. 'cuda', 'cpu').
    dev_kind = device_type.type if isinstance(device_type, torch.device) \
        else device_type
    # A bare index becomes '<kind>:<index>'; strings pass through untouched.
    target = f"{dev_kind}:{device}" if isinstance(device, int) else device

    def _zero() -> torch.Tensor:
        # Scalar long counter living on the resolved target device.
        return torch.tensor(0, dtype=torch.long, device=target)

    self.num_accepted_tokens = _zero()
    self.num_emitted_tokens = _zero()
@property
def probs_dtype(self):
    """Dtype used for the probability tensors consumed by this sampler."""
    return torch.float32
@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k