[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
This commit is contained in:
@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
|
||||
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
|
||||
dtypes).
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if (not torch.distributed.is_initialized()
|
||||
or torch.distributed.get_world_size(group=group) == 1):
|
||||
return tensor_dict
|
||||
|
||||
group = group or torch.distributed.group.WORLD
|
||||
metadata_group = metadata_group or get_cpu_world_group()
|
||||
ranks = torch.distributed.get_process_group_ranks(group)
|
||||
assert src in ranks, f"Invalid src rank ({src})"
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
world_size = torch.distributed.get_world_size(group=group)
|
||||
if world_size == 1:
|
||||
return tensor_dict
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
if rank == src:
|
||||
metadata_list: List[Tuple[Any, Any]] = []
|
||||
|
||||
Reference in New Issue
Block a user