Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -4,9 +4,9 @@
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from collections.abc import Sequence
|
||||
from numbers import Number
|
||||
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
|
||||
Type, Union)
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||
|
||||
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
||||
# bugs related to this test in PyTorch 2.4.
|
||||
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
||||
DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
|
||||
"test_schema",
|
||||
"test_autograd_registration",
|
||||
"test_faketensor",
|
||||
)
|
||||
|
||||
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
||||
ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
|
||||
"test_schema",
|
||||
"test_autograd_registration",
|
||||
"test_faketensor",
|
||||
@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple):
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
q_seq_lens: List[int]
|
||||
kv_seq_lens: List[int]
|
||||
q_seq_lens: list[int]
|
||||
kv_seq_lens: list[int]
|
||||
|
||||
|
||||
class QKVO(NamedTuple):
|
||||
@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple):
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
q_start_loc_list: Optional[List[int]]
|
||||
kv_start_loc_list: Optional[List[int]]
|
||||
q_seq_lens: Optional[List[int]]
|
||||
kv_seq_lens: Optional[List[int]]
|
||||
q_start_loc_list: Optional[list[int]]
|
||||
kv_start_loc_list: Optional[list[int]]
|
||||
q_seq_lens: Optional[list[int]]
|
||||
kv_seq_lens: Optional[list[int]]
|
||||
|
||||
|
||||
class PackedQKVO(NamedTuple):
|
||||
@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple):
|
||||
|
||||
|
||||
def maybe_make_int_tensor(
|
||||
_list: Optional[List[int]],
|
||||
_list: Optional[list[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
@@ -162,7 +162,7 @@ def maybe_make_int_tensor(
|
||||
|
||||
|
||||
def maybe_make_long_tensor(
|
||||
_list: Optional[List[int]],
|
||||
_list: Optional[list[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
@@ -177,7 +177,7 @@ def maybe_make_long_tensor(
|
||||
_list, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
def maybe_max(_list: Optional[List]) -> Optional[Number]:
|
||||
def maybe_max(_list: Optional[list]) -> Optional[Number]:
|
||||
'''
|
||||
Returns:
|
||||
|
||||
@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
custom_mask: Optional[torch.Tensor] = None,
|
||||
q_seq_lens: Optional[List] = None,
|
||||
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
|
||||
q_seq_lens: Optional[list] = None,
|
||||
kv_seq_lens: Optional[list] = None) -> torch.Tensor:
|
||||
'''
|
||||
"Golden" masked attention reference. Supports two types of masking:
|
||||
|
||||
@@ -295,10 +295,10 @@ def make_qkv(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
device: Union[torch.device, str],
|
||||
force_kv_seq_lens: Optional[List[int]] = None,
|
||||
force_kv_seq_lens: Optional[list[int]] = None,
|
||||
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
|
||||
force_max_len: bool = False,
|
||||
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
|
||||
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
|
||||
'''
|
||||
Construct QKV test tensors for self- and cross-attention.
|
||||
|
||||
@@ -429,8 +429,8 @@ def make_qkv(
|
||||
|
||||
|
||||
def pack_tensor(
|
||||
unpacked_tensor: torch.Tensor, seq_lens: List[int],
|
||||
device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
|
||||
unpacked_tensor: torch.Tensor, seq_lens: list[int],
|
||||
device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]:
|
||||
'''
|
||||
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
|
||||
unpadded number_of_tokens x num_heads x head_size tensor, where
|
||||
@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend:
|
||||
|
||||
|
||||
def _make_metadata_tensors(
|
||||
seq_lens: Optional[List[int]],
|
||||
context_lens: Optional[List[int]],
|
||||
encoder_seq_lens: Optional[List[int]],
|
||||
seq_lens: Optional[list[int]],
|
||||
context_lens: Optional[list[int]],
|
||||
encoder_seq_lens: Optional[list[int]],
|
||||
device: Union[torch.device, str],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
|
||||
torch.Tensor, torch.Tensor, Optional[int]]:
|
||||
'''
|
||||
Build scalar & tensor values required to build attention metadata structure.
|
||||
@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]):
|
||||
return torch.tensor([], device=device)
|
||||
|
||||
|
||||
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
|
||||
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
|
||||
device: Union[torch.device, str]):
|
||||
'''
|
||||
Split a slot mapping into valid prefill- and decode-phase slot mappings.
|
||||
@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
|
||||
|
||||
Arguments:
|
||||
|
||||
* slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
|
||||
* slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
|
||||
post-decode sequences
|
||||
* seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
|
||||
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
|
||||
description above)
|
||||
* device: cuda, cpu, etc.
|
||||
|
||||
@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
|
||||
|
||||
def make_block_tables_slot_mapping(
|
||||
block_size: int,
|
||||
seq_lens: List[int],
|
||||
seq_lens: list[int],
|
||||
device: Union[torch.device, str],
|
||||
block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
|
||||
block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]:
|
||||
'''
|
||||
Construct fake block tables & slot mappings.
|
||||
|
||||
@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping(
|
||||
def make_test_metadata(
|
||||
attn_backend: _Backend,
|
||||
is_prompt: bool,
|
||||
seq_lens: Optional[List[int]],
|
||||
seq_lens: Optional[list[int]],
|
||||
decoder_test_params: Optional[PhaseTestParameters],
|
||||
device: Union[torch.device, str],
|
||||
encoder_test_params: Optional[PhaseTestParameters] = None,
|
||||
@@ -1043,7 +1043,7 @@ def fp8_allclose(
|
||||
# Marlin MoE test utils
|
||||
|
||||
|
||||
def stack_and_dev(tensors: List[torch.Tensor]):
|
||||
def stack_and_dev(tensors: list[torch.Tensor]):
|
||||
dev = tensors[0].device
|
||||
return torch.stack(tensors, dim=0).to(dev)
|
||||
|
||||
@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk):
|
||||
# and a patched version of allclose that supports fp8 types.
|
||||
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
|
||||
torch._library.custom_ops.CustomOpDef],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
|
||||
raise_exception: bool = True,
|
||||
cond: bool = True) -> Dict[str, str]:
|
||||
cond: bool = True) -> dict[str, str]:
|
||||
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
|
||||
return torch.library.opcheck(
|
||||
op,
|
||||
@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# We treat N-dimensional group scaling as extended numpy-style broadcasting
|
||||
|
||||
Reference in New Issue
Block a user