39 lines
1012 B
Python
39 lines
1012 B
Python
|
|
import functools
|
||
|
|
import os
|
||
|
|
import torch
|
||
|
|
from typing import Callable
|
||
|
|
|
||
|
|
def get_arch_major(device=None) -> int:
    """Return the major CUDA compute-capability number of a device.

    Args:
        device: Forwarded to ``torch.cuda.get_device_capability``. ``None``
            (the default) queries the current device, which matches the
            previous zero-argument behaviour.

    Returns:
        The ``major`` component of the ``(major, minor)`` capability tuple.
    """
    major, _minor = torch.cuda.get_device_capability(device)
    return major
|
||
|
|
|
||
|
|
|
||
|
|
def test_filter(condition: Callable):
    """Build a decorator that skips the wrapped function unless ``condition()`` holds.

    ``condition`` is evaluated on every call, not once at decoration time.
    If it is truthy, the wrapped function runs and its return value is
    passed through. Otherwise, a short "filtered" notice naming the function
    and the condition is printed and ``None`` is returned.

    Args:
        condition: Zero-argument callable whose truthiness decides whether
            the wrapped function runs.

    Returns:
        A decorator that wraps a function, preserving its metadata via
        ``functools.wraps``.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if condition():
                # Propagate the result instead of silently discarding it.
                return func(*args, **kwargs)
            print(f'{func.__name__}:')
            print(f' > Filtered by {condition}')
            print()
            return None

        return wrapper

    return decorator


# The ``test_`` prefix would otherwise make pytest collect this helper as a
# test (and error out looking for a ``condition`` fixture) in any test module
# that imports it. ``__test__ = False`` opts it out of collection.
test_filter.__test__ = False
|
||
|
|
|
||
|
|
|
||
|
|
def ignore_env(name: str, condition: Callable):
    """Build a decorator that hides environment variable ``name`` during a call.

    ``condition`` is evaluated on every call. When it is truthy, ``name`` is
    removed from ``os.environ`` while the wrapped function runs. Its previous
    value is put back afterwards, including when the function raises. If the
    variable was unset beforehand, nothing is restored, so any value the
    function itself set is left in place. When ``condition()`` is falsy, the
    function runs with the environment untouched.

    Args:
        name: Environment variable to hide.
        condition: Zero-argument callable deciding whether to hide it.

    Returns:
        A decorator whose wrapper returns the wrapped function's result.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not condition():
                return func(*args, **kwargs)
            saved = os.environ.pop(name, None)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore in ``finally`` so an exception from ``func`` does not
                # leave the variable permanently removed for later callers.
                if saved is not None:
                    os.environ[name] = saved

        return wrapper

    return decorator
|