Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-12 17:51:31 +01:00 (committed by GitHub)
Parent: 9bb38130cb
Commit: 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions


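The pattern applied throughout is mechanical: PEP 604 (Python 3.10+) lets the binary `|` operator express unions directly in annotations, so `Optional[x]` becomes `x | None`, `Union[x, y]` becomes `x | y`, and the corresponding `typing` imports can be dropped. A minimal before/after sketch (the function is hypothetical, purely for illustration):

```python
# Before: requires Optional/Union imports from typing
from typing import Optional, Union

def lookup_old(key: Union[int, str], default: Optional[str] = None) -> Optional[str]:
    ...

# After: PEP 604 union syntax (runtime-evaluated annotations need Python >= 3.10,
# or `from __future__ import annotations` on older versions)
def lookup_new(key: int | str, default: str | None = None) -> str | None:
    ...
```

The hunks below show the change applied to one of the 944 files, a kernel test utilities module.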
@@ -7,7 +7,7 @@ import random
 import unittest
 from collections.abc import Sequence
 from numbers import Number
-from typing import Any, NamedTuple, Optional, Union
+from typing import Any, NamedTuple

 import pytest
 import torch
@@ -96,10 +96,10 @@ class PackedQKVInputs(NamedTuple):
     query: torch.Tensor
     key: torch.Tensor
     value: torch.Tensor
-    q_start_loc_list: Optional[list[int]]
-    kv_start_loc_list: Optional[list[int]]
-    q_seq_lens: Optional[list[int]]
-    kv_seq_lens: Optional[list[int]]
+    q_start_loc_list: list[int] | None
+    kv_start_loc_list: list[int] | None
+    q_seq_lens: list[int] | None
+    kv_seq_lens: list[int] | None


 class PackedQKVO(NamedTuple):
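Since the `| None` fields of this NamedTuple are exactly the per-sequence layout metadata, a hypothetical instantiation with that metadata absent might look like this (tensor shapes are illustrative only, assuming the num_tokens x num_heads x head_size packing described later in the file):

```python
import torch

# Hypothetical: packed QKV with no per-sequence metadata attached
qkv = PackedQKVInputs(
    query=torch.randn(32, 8, 64),  # num_tokens x num_heads x head_size
    key=torch.randn(32, 8, 64),
    value=torch.randn(32, 8, 64),
    q_start_loc_list=None,
    kv_start_loc_list=None,
    q_seq_lens=None,
    kv_seq_lens=None,
)
```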
@@ -115,7 +115,7 @@ class PackedQKVO(NamedTuple):
       x head_size) known-correct attention output
     """

-    packed_qkv: Optional[PackedQKVInputs]
+    packed_qkv: PackedQKVInputs | None
     ideal_output: torch.Tensor
@@ -149,12 +149,12 @@ class PhaseTestParameters(NamedTuple):
     """

     packed_qkvo: PackedQKVO
-    kv_mmap: Optional[KVMemoryMap]
+    kv_mmap: KVMemoryMap | None


 def maybe_make_int_tensor(
-    _list: Optional[list[int]],
-    device: Union[torch.device, str],
+    _list: list[int] | None,
+    device: torch.device | str,
 ) -> torch.Tensor:
     """
     Convert Python int list to a 1D int torch.Tensor on `device`
@@ -170,8 +170,8 @@ def maybe_make_int_tensor(
 def maybe_make_long_tensor(
-    _list: Optional[list[int]],
-    device: Union[torch.device, str],
+    _list: list[int] | None,
+    device: torch.device | str,
 ) -> torch.Tensor:
     """
     Convert Python int list to a 1D long torch.Tensor on `device`
@@ -186,7 +186,7 @@ def maybe_make_long_tensor(
     )


-def maybe_max(_list: Optional[list]) -> Optional[Number]:
+def maybe_max(_list: list | None) -> Number | None:
     """
     Returns:
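The bodies of these `maybe_*` helpers fall outside the visible hunks, so the following is only a sketch inferred from the signatures and docstrings: convert a possibly-None int list to a 1D tensor, and take a max that tolerates missing input.

```python
import torch
from numbers import Number


def maybe_make_int_tensor_sketch(
    _list: list[int] | None,
    device: torch.device | str,
) -> torch.Tensor:
    # An empty 1D int tensor stands in for a None/absent list.
    return torch.tensor([] if _list is None else _list, dtype=torch.int, device=device)


def maybe_max_sketch(_list: list | None) -> Number | None:
    # Propagate None rather than raising on missing input.
    return None if _list is None else max(_list)
```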
@@ -241,9 +241,9 @@ def ref_masked_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
-    custom_mask: Optional[torch.Tensor] = None,
-    q_seq_lens: Optional[list] = None,
-    kv_seq_lens: Optional[list] = None,
+    custom_mask: torch.Tensor | None = None,
+    q_seq_lens: list | None = None,
+    kv_seq_lens: list | None = None,
 ) -> torch.Tensor:
     """
     "Golden" masked attention reference. Supports two types of masking:
@@ -302,11 +302,11 @@ def ref_masked_attention(
 def make_qkv(
     batch_size: int,
     max_q_seq_len: int,
-    max_kv_seq_len: Optional[int],
+    max_kv_seq_len: int | None,
     num_heads: int,
     head_size: int,
-    device: Union[torch.device, str],
-    force_kv_seq_lens: Optional[list[int]] = None,
+    device: torch.device | str,
+    force_kv_seq_lens: list[int] | None = None,
     attn_type: AttentionType = AttentionType.ENCODER_DECODER,
     force_max_len: bool = False,
 ) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
@@ -436,7 +436,7 @@ def make_qkv(
 def pack_tensor(
-    unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str]
+    unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str
 ) -> tuple[torch.Tensor, list[int]]:
     """
     Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
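Packing, as the docstring describes it, strips per-sequence padding and concatenates the valid tokens, returning the packed tensor together with each sequence's start offset. A sketch under that reading (an assumption; the real body is outside the hunk):

```python
import torch


def pack_tensor_sketch(
    unpacked_tensor: torch.Tensor,  # batch x padded_seq_len x heads x head_size
    seq_lens: list[int],
    device: torch.device | str,
) -> tuple[torch.Tensor, list[int]]:
    start_loc_list = [0]
    pieces = []
    for batch_idx, seq_len in enumerate(seq_lens):
        pieces.append(unpacked_tensor[batch_idx, :seq_len])  # drop padding rows
        start_loc_list.append(start_loc_list[-1] + seq_len)
    # num_tokens x heads x head_size, plus cumulative start offsets
    return torch.cat(pieces).to(device), start_loc_list
```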
@@ -470,7 +470,7 @@ def pack_tensor(
     return packed_tensor, start_loc_list


-def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs:
+def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     """
     Individually pack each of Q, K and V, each with dimensions batch_size x
     padded_seq_len x num_heads x head_size, into respective number_of_tokens x
@@ -594,19 +594,19 @@ def make_alibi_bias(
 def _make_metadata_tensors(
-    seq_lens: Optional[list[int]],
-    context_lens: Optional[list[int]],
-    encoder_seq_lens: Optional[list[int]],
-    device: Union[torch.device, str],
+    seq_lens: list[int] | None,
+    context_lens: list[int] | None,
+    encoder_seq_lens: list[int] | None,
+    device: torch.device | str,
 ) -> tuple[
     torch.Tensor,
     torch.Tensor,
     Any,
     Any,
-    Optional[torch.Tensor],
+    torch.Tensor | None,
     torch.Tensor,
     torch.Tensor,
-    Optional[int],
+    int | None,
 ]:
     """
     Build scalar & tensor values required to build attention metadata structure.
@@ -678,7 +678,7 @@ def make_kv_cache(
     num_heads: int,
     head_size: int,
     block_size: int,
-    device: Union[torch.device, str],
+    device: torch.device | str,
     backend: str,
     default_val: float = 0.0,
 ) -> torch.Tensor:
@@ -726,18 +726,18 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
     return (num_tokens + block_size) // block_size


-def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
+def make_empty_slot_mapping_tensor(device: torch.device | str):
     return maybe_make_long_tensor([], device)


-def make_empty_block_tables_tensor(device: Union[torch.device, str]):
+def make_empty_block_tables_tensor(device: torch.device | str):
     return torch.tensor([], device=device)


 def split_slot_mapping(
     slot_mapping_list: torch.Tensor,
     seq_lens: list[int],
-    device: Union[torch.device, str],
+    device: torch.device | str,
 ):
     """
     Split a slot mapping into valid prefill- and decode-phase slot mappings.
@@ -799,7 +799,7 @@ def split_slot_mapping(
 def make_block_tables_slot_mapping(
     block_size: int,
     seq_lens: list[int],
-    device: Union[torch.device, str],
+    device: torch.device | str,
     block_base_addr: int = 0,
 ) -> tuple[torch.Tensor, list[int], int]:
     """
@@ -880,11 +880,11 @@ def make_block_tables_slot_mapping(
 def make_test_metadata(
     attn_backend: _Backend,
     is_prompt: bool,
-    seq_lens: Optional[list[int]],
-    decoder_test_params: Optional[PhaseTestParameters],
-    device: Union[torch.device, str],
-    encoder_test_params: Optional[PhaseTestParameters] = None,
-    cross_test_params: Optional[PhaseTestParameters] = None,
+    seq_lens: list[int] | None,
+    decoder_test_params: PhaseTestParameters | None,
+    device: torch.device | str,
+    encoder_test_params: PhaseTestParameters | None = None,
+    cross_test_params: PhaseTestParameters | None = None,
 ) -> AttentionMetadata:
     """
     Construct fake attention metadata for a given test phase
@@ -1142,16 +1142,16 @@ def torch_experts(
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
     global_num_experts: int = -1,
-    b_bias1: Optional[torch.Tensor] = None,
-    b_bias2: Optional[torch.Tensor] = None,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    b_bias1: torch.Tensor | None = None,
+    b_bias2: torch.Tensor | None = None,
+    expert_map: torch.Tensor | None = None,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant=False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
     apply_router_weights_on_input: bool = False,
 ) -> torch.Tensor:
     assert (
@@ -1261,10 +1261,10 @@ def torch_moe(
     w2: torch.Tensor,
     score: torch.Tensor,
     topk: int,
-    b_bias1: Optional[torch.Tensor] = None,
-    b_bias2: Optional[torch.Tensor] = None,
+    b_bias1: torch.Tensor | None = None,
+    b_bias2: torch.Tensor | None = None,
     global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
     score = torch.softmax(score, dim=-1, dtype=torch.float32)
     topk_weight, topk_ids = torch.topk(score, topk)
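The two routing lines visible at the end of this hunk are the standard softmax-then-top-k MoE gate: the `topk_weight`/`topk_ids` pair is what `torch_experts` above consumes. Expanded into a runnable snippet:

```python
import torch

num_tokens, num_experts, topk = 4, 8, 2
score = torch.randn(num_tokens, num_experts)  # router logits

# Per-token expert probabilities, then the top-k experts per token
probs = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(probs, topk)
print(topk_ids.shape)     # torch.Size([4, 2]) -- chosen expert indices
print(topk_weight.shape)  # torch.Size([4, 2]) -- their gate weights
```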
@@ -1298,15 +1298,13 @@ def torch_moe_single(a, w, score, topk):
 # A special version of op check that has a restricted default set of test_utils
 # and a patched version of allclose that supports fp8 types.
 def opcheck(
-    op: Union[
-        torch._ops.OpOverload,
-        torch._ops.OpOverloadPacket,
-        torch._library.custom_ops.CustomOpDef,
-    ],
+    op: torch._ops.OpOverload
+    | torch._ops.OpOverloadPacket
+    | torch._library.custom_ops.CustomOpDef,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: dict[str, Any] | None = None,
     *,
-    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+    test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS,
     raise_exception: bool = True,
     cond: bool = True,
 ) -> dict[str, str]:
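This wrapper sits on top of PyTorch's `torch.library.opcheck`, which accepts exactly the op types listed in the union above. For reference, a minimal direct use of the underlying API looks like this (the op name `mylib::my_sin` is hypothetical):

```python
import torch

# Hypothetical custom op; opcheck also accepts OpOverload/OpOverloadPacket
@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x)

@my_sin.register_fake
def _(x):
    return torch.empty_like(x)

# Validate the registration against a sample input; raises on failure
torch.library.opcheck(
    my_sin, (torch.randn(3),), test_utils=("test_schema", "test_faketensor")
)
```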
@@ -1338,7 +1336,7 @@ def baseline_scaled_mm(
     scale_a: torch.Tensor,
     scale_b: torch.Tensor,
     out_dtype: type[torch.dtype],
-    bias: Optional[torch.Tensor] = None,
+    bias: torch.Tensor | None = None,
 ) -> torch.Tensor:
     # We treat N-dimensional group scaling as extended numpy-style broadcasting
     # in numpy simply stretches dimensions with an extent of 1 to match
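Under the broadcasting interpretation in that comment, a dequantize-then-matmul baseline might read as follows (a sketch, not the file's exact code; the scale shapes are assumed broadcastable to their operands):

```python
import torch


def baseline_scaled_mm_sketch(
    a: torch.Tensor,                  # M x K, quantized values
    b: torch.Tensor,                  # K x N, quantized values
    scale_a: torch.Tensor,            # broadcastable to a's shape
    scale_b: torch.Tensor,            # broadcastable to b's shape
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Dequantize via broadcasting, matmul in fp32, then cast down
    out = (scale_a * a.to(torch.float32)) @ (scale_b * b.to(torch.float32))
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
```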