Remove hard dependencies of speculative decoding on CUDA workers (#10587)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
def init_tensors(self,
                 device: Union[int, str],
                 device_type: Union[torch.device, str] = 'cuda') -> None:
    """Create the zero-valued accepted/emitted token counters.

    Both ``self.num_accepted_tokens`` and ``self.num_emitted_tokens`` are
    set to scalar ``torch.long`` tensors holding 0, placed on the resolved
    device.

    Args:
        device: Either a ready-to-use device string, or an integer index
            that is combined with ``device_type`` into
            ``"<device_type>:<index>"``.
        device_type: Device family used only when ``device`` is an int.
            A ``torch.device`` is reduced to its ``.type`` string.

    Raises:
        AssertionError: If ``self.num_accepted_tokens`` is not ``None``
            when this method is called.
    """
    assert self.num_accepted_tokens is None

    # Only the family name (e.g. the ``type`` of a torch.device) is needed
    # to build the device string below.
    device_type = (device_type.type
                   if isinstance(device_type, torch.device) else device_type)

    # A bare integer index is expanded into a full device string; a string
    # is used as given.
    device = f"{device_type}:{device}" if isinstance(device, int) else device

    counter_kwargs = dict(dtype=torch.long, device=device)
    self.num_accepted_tokens = torch.tensor(0, **counter_kwargs)
    self.num_emitted_tokens = torch.tensor(0, **counter_kwargs)
|
||||
|
||||
@property
def probs_dtype(self) -> torch.dtype:
    """Return ``torch.float32`` unconditionally.

    NOTE(review): presumably the dtype callers should use for probability
    tensors (inferred from the name) -- confirm against call sites.
    """
    return torch.float32
|
||||
@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
tensor is [batch_size, k + num_bonus_tokens]
|
||||
"""
|
||||
batch_size, k = substitute_token_ids.shape
|
||||
bonus_token_ids = bonus_token_ids.squeeze()
|
||||
bonus_token_ids = bonus_token_ids.squeeze(-1)
|
||||
# Determine the index of the first False value for each row.
|
||||
limits = (accepted == 0).max(1).indices
|
||||
limits[~(accepted == 0).any(1)] = k
|
||||
|
||||
Reference in New Issue
Block a user