- Add BF16 support for SM90 and SM100
- Refactor Python APIs
- Other fixes and code refactoring
16 lines
306 B
Python
16 lines
306 B
Python
import torch
|
|
import torch.multiprocessing as mp
|
|
import deep_gemm
|
|
|
|
|
|
def main(local_rank: int):
    """Worker entry point: select the CUDA device for this process.

    Args:
        local_rank: index passed straight to ``torch.cuda.set_device`` as the
            device to make current in this worker process.
    """
    torch.cuda.set_device(device=local_rank)
if __name__ == '__main__':
    # Ranks 0..7 are hard-coded, so this presumably assumes an 8-GPU node --
    # TODO confirm.
    # NOTE(review): do not swap this for torch.cuda.device_count() without
    # checking that the call leaves CUDA uninitialized in the parent; children
    # are started with the platform's default start method (fork on Linux
    # before Python 3.14), and a CUDA-initialized parent can break forked
    # children.
    num_ranks = 8
    procs = [mp.Process(target=main, args=(rank,)) for rank in range(num_ranks)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    # join() does not propagate failures from the children: a worker that
    # raised still leaves this script exiting with status 0.  Check each
    # exit code (non-zero = error exit, negative = killed by a signal) so
    # failures are visible to whatever runs this script.
    failed = [rank for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        raise SystemExit(f'worker processes for local ranks {failed} exited abnormally')