[Core][Distributed] remove graph mode function (#4818)

This commit is contained in:
youkaichao
2024-05-16 10:59:52 -07:00
committed by GitHub
parent b5853f9963
commit e08188081b
4 changed files with 63 additions and 54 deletions

View File

@@ -5,7 +5,7 @@ import pytest
import torch
from vllm.distributed.communication_op import ( # noqa
graph_mode, tensor_model_parallel_all_reduce)
graph_capture, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_mode():
with graph_capture():
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)