[BugFix] Fix use of per-request seed with pipeline parallel (#6698)

This commit is contained in:
Nick Hill
2024-07-30 10:40:08 -07:00
committed by GitHub
parent f058403683
commit 5cf9254a9c
21 changed files with 222 additions and 137 deletions

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, Optional
import torch
import torch.jit
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError