[mypy] Misc. typing improvements (#7417)

This commit is contained in:
Cyrus Leung
2024-08-13 09:20:20 +08:00
committed by GitHub
parent 198d6a2898
commit 9ba85bc152
16 changed files with 74 additions and 75 deletions

View File

@@ -1,10 +1,12 @@
import contextlib
import functools
import gc
from typing import Callable, TypeVar
import pytest
import ray
import torch
from typing_extensions import ParamSpec
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
@@ -22,12 +24,16 @@ def cleanup():
torch.cuda.empty_cache()
def retry_until_skip(n):
_P = ParamSpec("_P")
_R = TypeVar("_R")
def decorator_retry(func):
def retry_until_skip(n: int):
def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(func)
def wrapper_retry(*args, **kwargs):
def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
for i in range(n):
try:
return func(*args, **kwargs)
@@ -35,7 +41,9 @@ def retry_until_skip(n):
gc.collect()
torch.cuda.empty_cache()
if i == n - 1:
pytest.skip("Skipping test after attempts..")
pytest.skip(f"Skipping test after {n} attempts.")
raise AssertionError("Code should not be reached")
return wrapper_retry