Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,8 +8,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
||||
SimpleBuffer)
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer
|
||||
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
|
||||
|
||||
# TODO: the test depends on a lot of fields in the current implementation.
|
||||
@@ -17,7 +16,6 @@ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
|
||||
|
||||
|
||||
def test_run(my_rank, buffer, device):
|
||||
|
||||
# buffer should be empty in the beginning
|
||||
if my_rank == 0:
|
||||
assert buffer.buffer_size == 0
|
||||
@@ -27,7 +25,7 @@ def test_run(my_rank, buffer, device):
|
||||
|
||||
# insert
|
||||
tokens = torch.tensor([1, 2, 3]).to(device)
|
||||
roi = (tokens > 0)
|
||||
roi = tokens > 0
|
||||
if my_rank == 0:
|
||||
key = 2.0 * torch.ones([5, 6]).to(device)
|
||||
value = 3.0 * torch.ones([5, 6]).to(device)
|
||||
@@ -55,7 +53,6 @@ def test_run(my_rank, buffer, device):
|
||||
|
||||
|
||||
def stress_test(my_rank, buf, device):
|
||||
|
||||
torch.distributed.barrier()
|
||||
torch.manual_seed(100)
|
||||
|
||||
@@ -66,7 +63,8 @@ def stress_test(my_rank, buf, device):
|
||||
torch.rand(100).to(device), # key
|
||||
torch.rand(100).to(device), # value
|
||||
torch.rand(100).to(device), # hidden
|
||||
) for i in tqdm(range(200))
|
||||
)
|
||||
for i in tqdm(range(200))
|
||||
]
|
||||
|
||||
random.seed(my_rank)
|
||||
@@ -115,12 +113,11 @@ def stress_test(my_rank, buf, device):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
my_rank = int(os.environ['RANK'])
|
||||
my_rank = int(os.environ["RANK"])
|
||||
|
||||
torch.distributed.init_process_group(
|
||||
backend='gloo',
|
||||
init_method='tcp://localhost:12398',
|
||||
backend="gloo",
|
||||
init_method="tcp://localhost:12398",
|
||||
world_size=2,
|
||||
rank=my_rank,
|
||||
)
|
||||
@@ -128,8 +125,8 @@ if __name__ == "__main__":
|
||||
print(f"initialized! My rank is {my_rank}")
|
||||
|
||||
config = KVTransferConfig(
|
||||
kv_connector='P2pNcclConnector',
|
||||
kv_buffer_device='cuda',
|
||||
kv_connector="P2pNcclConnector",
|
||||
kv_buffer_device="cuda",
|
||||
kv_buffer_size=1e9,
|
||||
kv_rank=my_rank,
|
||||
kv_role="kv_both", # this arg doesn't matter in this test
|
||||
@@ -160,4 +157,4 @@ if __name__ == "__main__":
|
||||
buffer.close()
|
||||
data_pipe.close()
|
||||
cpu_pipe.close()
|
||||
print('Done')
|
||||
print("Done")
|
||||
|
||||
Reference in New Issue
Block a user