[SpecDecode] [Minor] Fix spec decode sampler tests (#7183)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
|
||||
self.num_emitted_tokens: Optional[torch.Tensor] = None
|
||||
self.num_draft_tokens: int = 0
|
||||
|
||||
def init_gpu_tensors(self, rank: int) -> None:
|
||||
def init_gpu_tensors(self, device: Union[int, str]) -> None:
|
||||
assert self.num_accepted_tokens is None
|
||||
device = f"cuda:{rank}"
|
||||
if isinstance(device, int):
|
||||
device = f"cuda:{device}"
|
||||
elif not isinstance(device, str):
|
||||
raise ValueError(f"Device must be int or str, get {type(device)}")
|
||||
self.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
Reference in New Issue
Block a user