[mypy] Misc. typing improvements (#7417)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
@@ -22,12 +24,16 @@ def cleanup():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def retry_until_skip(n):
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
def decorator_retry(func):
|
||||
|
||||
def retry_until_skip(n: int):
|
||||
|
||||
def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper_retry(*args, **kwargs):
|
||||
def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
for i in range(n):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -35,7 +41,9 @@ def retry_until_skip(n):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if i == n - 1:
|
||||
pytest.skip("Skipping test after attempts..")
|
||||
pytest.skip(f"Skipping test after {n} attempts.")
|
||||
|
||||
raise AssertionError("Code should not be reached")
|
||||
|
||||
return wrapper_retry
|
||||
|
||||
|
||||
Reference in New Issue
Block a user