Update deprecated Python 3.8 typing (#13971)

Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions
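
The sweep applies the PEP 585 cleanup that became possible once Python 3.8 support was dropped: the deprecated `typing` aliases `List`, `Dict`, `Tuple`, and `Type` become the builtin generics `list`, `dict`, `tuple`, and `type`, and abstract container types such as `Sequence` are imported from `collections.abc` instead of `typing` (`Any`, `Optional`, `Union`, and `NamedTuple` remain in `typing`). A minimal before/after sketch of the pattern, using hypothetical names not taken from the diff:

# Illustrative only; requires Python >= 3.9 (PEP 585). The names here are
# hypothetical, not from this commit.

# Before (Python 3.8 style):
#   from typing import Dict, List, Optional, Sequence, Tuple
#   def tally(items: Sequence[str],
#             skip: Optional[str] = None) -> Tuple[Dict[str, int], List[str]]:

# After (the style this commit adopts):
from collections.abc import Sequence
from typing import Optional


def tally(items: Sequence[str],
          skip: Optional[str] = None) -> tuple[dict[str, int], list[str]]:
    """Count occurrences and record distinct items in first-seen order."""
    counts: dict[str, int] = {}
    order: list[str] = []
    for item in items:
        if item == skip:
            continue
        if item not in counts:
            order.append(item)
        counts[item] = counts.get(item, 0) + 1
    return counts, order


assert tally(["a", "b", "a"]) == ({"a": 2, "b": 1}, ["a", "b"])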

@@ -4,9 +4,9 @@
 import itertools
 import random
 import unittest
+from collections.abc import Sequence
 from numbers import Number
-from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
-                    Type, Union)
+from typing import Any, NamedTuple, Optional, Union
 
 import pytest
 import torch
@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.
-DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
     "test_schema",
     "test_autograd_registration",
     "test_faketensor",
 )
 
-ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
     "test_schema",
     "test_autograd_registration",
     "test_faketensor",
@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple):
     query: torch.Tensor
     key: torch.Tensor
     value: torch.Tensor
-    q_seq_lens: List[int]
-    kv_seq_lens: List[int]
+    q_seq_lens: list[int]
+    kv_seq_lens: list[int]
 
 
 class QKVO(NamedTuple):
@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple):
     query: torch.Tensor
     key: torch.Tensor
     value: torch.Tensor
-    q_start_loc_list: Optional[List[int]]
-    kv_start_loc_list: Optional[List[int]]
-    q_seq_lens: Optional[List[int]]
-    kv_seq_lens: Optional[List[int]]
+    q_start_loc_list: Optional[list[int]]
+    kv_start_loc_list: Optional[list[int]]
+    q_seq_lens: Optional[list[int]]
+    kv_seq_lens: Optional[list[int]]
 
 
 class PackedQKVO(NamedTuple):
@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple):
 def maybe_make_int_tensor(
-    _list: Optional[List[int]],
+    _list: Optional[list[int]],
     device: Union[torch.device, str],
 ) -> torch.Tensor:
     '''
@@ -162,7 +162,7 @@ def maybe_make_int_tensor(
 def maybe_make_long_tensor(
-    _list: Optional[List[int]],
+    _list: Optional[list[int]],
     device: Union[torch.device, str],
 ) -> torch.Tensor:
     '''
@@ -177,7 +177,7 @@ def maybe_make_long_tensor(
                 _list, dtype=torch.long, device=device)
 
 
-def maybe_max(_list: Optional[List]) -> Optional[Number]:
+def maybe_max(_list: Optional[list]) -> Optional[Number]:
     '''
     Returns:
@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor,
                          value: torch.Tensor,
                          scale: float,
                          custom_mask: Optional[torch.Tensor] = None,
-                         q_seq_lens: Optional[List] = None,
-                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
+                         q_seq_lens: Optional[list] = None,
+                         kv_seq_lens: Optional[list] = None) -> torch.Tensor:
     '''
     "Golden" masked attention reference. Supports two types of masking:
@@ -295,10 +295,10 @@ def make_qkv(
     num_heads: int,
     head_size: int,
     device: Union[torch.device, str],
-    force_kv_seq_lens: Optional[List[int]] = None,
+    force_kv_seq_lens: Optional[list[int]] = None,
     attn_type: AttentionType = AttentionType.ENCODER_DECODER,
     force_max_len: bool = False,
-) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
+) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
     '''
     Construct QKV test tensors for self- and cross-attention.
@@ -429,8 +429,8 @@ def make_qkv(
 def pack_tensor(
-        unpacked_tensor: torch.Tensor, seq_lens: List[int],
-        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
+        unpacked_tensor: torch.Tensor, seq_lens: list[int],
+        device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]:
     '''
     Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
     unpadded number_of_tokens x num_heads x head_size tensor, where
@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend:
 def _make_metadata_tensors(
-    seq_lens: Optional[List[int]],
-    context_lens: Optional[List[int]],
-    encoder_seq_lens: Optional[List[int]],
+    seq_lens: Optional[list[int]],
+    context_lens: Optional[list[int]],
+    encoder_seq_lens: Optional[list[int]],
     device: Union[torch.device, str],
-) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
+) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
            torch.Tensor, torch.Tensor, Optional[int]]:
     '''
     Build scalar & tensor values required to build attention metadata structure.
@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]):
     return torch.tensor([], device=device)
 
 
-def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
+def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
                        device: Union[torch.device, str]):
     '''
     Split a slot mapping into valid prefill- and decode-phase slot mappings.
@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
     Arguments:
 
-    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
+    * slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
       post-decode sequences
-    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
+    * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
       description above)
     * device: cuda, cpu, etc.
@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
 def make_block_tables_slot_mapping(
         block_size: int,
-        seq_lens: List[int],
+        seq_lens: list[int],
         device: Union[torch.device, str],
-        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
+        block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]:
     '''
     Construct fake block tables & slot mappings.
@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping(
 def make_test_metadata(
     attn_backend: _Backend,
     is_prompt: bool,
-    seq_lens: Optional[List[int]],
+    seq_lens: Optional[list[int]],
     decoder_test_params: Optional[PhaseTestParameters],
     device: Union[torch.device, str],
     encoder_test_params: Optional[PhaseTestParameters] = None,
@@ -1043,7 +1043,7 @@ def fp8_allclose(
 # Marlin MoE test utils
 
 
-def stack_and_dev(tensors: List[torch.Tensor]):
+def stack_and_dev(tensors: list[torch.Tensor]):
     dev = tensors[0].device
     return torch.stack(tensors, dim=0).to(dev)
@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk):
 # and a patched version of allclose that supports fp8 types.
 def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                       torch._library.custom_ops.CustomOpDef],
-            args: Tuple[Any, ...],
-            kwargs: Optional[Dict[str, Any]] = None,
+            args: tuple[Any, ...],
+            kwargs: Optional[dict[str, Any]] = None,
             *,
             test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
             raise_exception: bool = True,
-            cond: bool = True) -> Dict[str, str]:
+            cond: bool = True) -> dict[str, str]:
     with unittest.mock.patch('torch.allclose', new=fp8_allclose):
         return torch.library.opcheck(
             op,
@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor,
                        b: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor,
-                       out_dtype: Type[torch.dtype],
+                       out_dtype: type[torch.dtype],
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
     # We treat N-dimensional group scaling as extended numpy-style broadcasting
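
A quick way to sanity-check a sweep like this is to confirm that the rewritten annotations still resolve under runtime introspection: on Python 3.9+ the builtin generics are real runtime objects, so `typing.get_type_hints` handles them just as it handled the old aliases. A small hypothetical check, not part of this commit:

# Hypothetical sanity check, not part of this commit (Python >= 3.9).
from typing import Optional, get_type_hints


def maybe_make_lens(_list: Optional[list[int]] = None) -> tuple[int, ...]:
    # Mirrors the signature style used throughout the diff above.
    return tuple(_list or [])


hints = get_type_hints(maybe_make_lens)
assert hints["return"] == tuple[int, ...]
print(hints["_list"])  # typing.Optional[list[int]]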