2026-04-17 09:45:14 +08:00
|
|
|
import argparse
|
2025-08-15 18:32:35 +08:00
|
|
|
import torch
|
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
|
import deep_gemm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(local_rank: int):
    """Bind the current worker process to CUDA device ``local_rank``.

    Run once per child process; the parent passes each child its index
    from ``range(num_processes)``.

    NOTE(review): presumably this is the check behind "lazy initialization":
    ``deep_gemm`` is imported at module scope in the parent, and if that
    import eagerly initialized CUDA, selecting a device in a forked child
    would fail here. Confirm this is the intended test scope.
    """
    device_index = local_rank
    torch.cuda.set_device(device_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Only the worker count is configurable from the command line.
    parser = argparse.ArgumentParser(description='Test lazy initialization')
    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
    args = parser.parse_args()
    # With zero (or negative) workers nothing is exercised and the test would
    # "pass" vacuously, so reject such values up front.
    if args.num_processes < 1:
        parser.error('--num-processes must be a positive integer')

    # One worker per rank; each runs main(rank). The default start method of
    # torch.multiprocessing is used (platform/Python-version dependent).
    procs = [mp.Process(target=main, args=(i, )) for i in range(args.num_processes)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    # Process.join() does not raise when a child fails, so without this check
    # a crashing worker would still leave the script exiting with status 0.
    failed = [(rank, p.exitcode) for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(f'Lazy initialization test failed; (rank, exitcode): {failed}')
|