[SpecDecode] [Minor] Fix spec decode sampler tests (#7183)

This commit is contained in:
Lily Liu
2024-08-06 10:40:32 -07:00
committed by GitHub
parent 00afc78590
commit 5c60c8c423
3 changed files with 22 additions and 19 deletions

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Union
import torch
import torch.jit
@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
if isinstance(device, int):
device = f"cuda:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)