[BugFix] Fix use of per-request seed with pipeline parallel (#6698)
This commit is contained in:
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
))
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
generators: Dict[str, torch.Generator] = {}
|
||||
|
||||
def test_sampling():
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
pin_memory=is_pin_memory_available(),
|
||||
generators=generators)
|
||||
sampler_output = sampler(logits=fake_logits,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user