Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
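For reference, the conversion pattern applied throughout, sketched on a hypothetical function (the name and parameters below are illustrative, not taken from this diff). The `X | Y` spelling is PEP 604 syntax, available natively on Python 3.10+; on older interpreters it is valid in annotations only under `from __future__ import annotations`:

    # Before: typing-module generics, which require extra imports.
    from typing import Optional, Union

    def scale(value: Optional[float], factor: Union[int, float] = 1) -> Optional[float]:
        return None if value is None else value * factor

    # After: PEP 604 unions; the Optional and Union imports can be dropped.
    def scale(value: float | None, factor: int | float = 1) -> float | None:
        return None if value is None else value * factor
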
@@ -7,7 +7,7 @@ import random
 import unittest
 from collections.abc import Sequence
 from numbers import Number
-from typing import Any, NamedTuple, Optional, Union
+from typing import Any, NamedTuple

 import pytest
 import torch
@@ -96,10 +96,10 @@ class PackedQKVInputs(NamedTuple):
     query: torch.Tensor
     key: torch.Tensor
     value: torch.Tensor
-    q_start_loc_list: Optional[list[int]]
-    kv_start_loc_list: Optional[list[int]]
-    q_seq_lens: Optional[list[int]]
-    kv_seq_lens: Optional[list[int]]
+    q_start_loc_list: list[int] | None
+    kv_start_loc_list: list[int] | None
+    q_seq_lens: list[int] | None
+    kv_seq_lens: list[int] | None


 class PackedQKVO(NamedTuple):
@@ -115,7 +115,7 @@ class PackedQKVO(NamedTuple):
     x head_size) known-correct attention output
     """

-    packed_qkv: Optional[PackedQKVInputs]
+    packed_qkv: PackedQKVInputs | None
     ideal_output: torch.Tensor

@@ -149,12 +149,12 @@ class PhaseTestParameters(NamedTuple):
     """

     packed_qkvo: PackedQKVO
-    kv_mmap: Optional[KVMemoryMap]
+    kv_mmap: KVMemoryMap | None


 def maybe_make_int_tensor(
-    _list: Optional[list[int]],
-    device: Union[torch.device, str],
+    _list: list[int] | None,
+    device: torch.device | str,
 ) -> torch.Tensor:
     """
     Convert Python int list to a 1D int torch.Tensor on `device`
@@ -170,8 +170,8 @@ def maybe_make_int_tensor(


 def maybe_make_long_tensor(
-    _list: Optional[list[int]],
-    device: Union[torch.device, str],
+    _list: list[int] | None,
+    device: torch.device | str,
 ) -> torch.Tensor:
     """
     Convert Python int list to a 1D long torch.Tensor on `device`
@@ -186,7 +186,7 @@ def maybe_make_long_tensor(
     )


-def maybe_max(_list: Optional[list]) -> Optional[Number]:
+def maybe_max(_list: list | None) -> Number | None:
     """
     Returns:

@@ -241,9 +241,9 @@ def ref_masked_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
-    custom_mask: Optional[torch.Tensor] = None,
-    q_seq_lens: Optional[list] = None,
-    kv_seq_lens: Optional[list] = None,
+    custom_mask: torch.Tensor | None = None,
+    q_seq_lens: list | None = None,
+    kv_seq_lens: list | None = None,
 ) -> torch.Tensor:
     """
     "Golden" masked attention reference. Supports two types of masking:
@@ -302,11 +302,11 @@ def ref_masked_attention(
 def make_qkv(
     batch_size: int,
     max_q_seq_len: int,
-    max_kv_seq_len: Optional[int],
+    max_kv_seq_len: int | None,
     num_heads: int,
     head_size: int,
-    device: Union[torch.device, str],
-    force_kv_seq_lens: Optional[list[int]] = None,
+    device: torch.device | str,
+    force_kv_seq_lens: list[int] | None = None,
     attn_type: AttentionType = AttentionType.ENCODER_DECODER,
     force_max_len: bool = False,
 ) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
@@ -436,7 +436,7 @@ def make_qkv(


 def pack_tensor(
-    unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str]
+    unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str
 ) -> tuple[torch.Tensor, list[int]]:
     """
     Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
@@ -470,7 +470,7 @@ def pack_tensor(
     return packed_tensor, start_loc_list


-def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs:
+def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     """
     Individually pack each of Q, K and V, each with dimensions batch_size x
     padded_seq_len x num_heads x head_size, into respective number_of_tokens x
@@ -594,19 +594,19 @@ def make_alibi_bias(


 def _make_metadata_tensors(
-    seq_lens: Optional[list[int]],
-    context_lens: Optional[list[int]],
-    encoder_seq_lens: Optional[list[int]],
-    device: Union[torch.device, str],
+    seq_lens: list[int] | None,
+    context_lens: list[int] | None,
+    encoder_seq_lens: list[int] | None,
+    device: torch.device | str,
 ) -> tuple[
     torch.Tensor,
     torch.Tensor,
     Any,
     Any,
-    Optional[torch.Tensor],
+    torch.Tensor | None,
     torch.Tensor,
     torch.Tensor,
-    Optional[int],
+    int | None,
 ]:
     """
     Build scalar & tensor values required to build attention metadata structure.
@@ -678,7 +678,7 @@ def make_kv_cache(
     num_heads: int,
     head_size: int,
     block_size: int,
-    device: Union[torch.device, str],
+    device: torch.device | str,
     backend: str,
     default_val: float = 0.0,
 ) -> torch.Tensor:
@@ -726,18 +726,18 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
     return (num_tokens + block_size) // block_size


-def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
+def make_empty_slot_mapping_tensor(device: torch.device | str):
     return maybe_make_long_tensor([], device)


-def make_empty_block_tables_tensor(device: Union[torch.device, str]):
+def make_empty_block_tables_tensor(device: torch.device | str):
     return torch.tensor([], device=device)


 def split_slot_mapping(
     slot_mapping_list: torch.Tensor,
     seq_lens: list[int],
-    device: Union[torch.device, str],
+    device: torch.device | str,
 ):
     """
     Split a slot mapping into valid prefill- and decode-phase slot mappings.
@@ -799,7 +799,7 @@ def split_slot_mapping(
 def make_block_tables_slot_mapping(
     block_size: int,
     seq_lens: list[int],
-    device: Union[torch.device, str],
+    device: torch.device | str,
     block_base_addr: int = 0,
 ) -> tuple[torch.Tensor, list[int], int]:
     """
@@ -880,11 +880,11 @@ def make_block_tables_slot_mapping(
 def make_test_metadata(
     attn_backend: _Backend,
     is_prompt: bool,
-    seq_lens: Optional[list[int]],
-    decoder_test_params: Optional[PhaseTestParameters],
-    device: Union[torch.device, str],
-    encoder_test_params: Optional[PhaseTestParameters] = None,
-    cross_test_params: Optional[PhaseTestParameters] = None,
+    seq_lens: list[int] | None,
+    decoder_test_params: PhaseTestParameters | None,
+    device: torch.device | str,
+    encoder_test_params: PhaseTestParameters | None = None,
+    cross_test_params: PhaseTestParameters | None = None,
 ) -> AttentionMetadata:
     """
     Construct fake attention metadata for a given test phase
@@ -1142,16 +1142,16 @@ def torch_experts(
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
     global_num_experts: int = -1,
-    b_bias1: Optional[torch.Tensor] = None,
-    b_bias2: Optional[torch.Tensor] = None,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    b_bias1: torch.Tensor | None = None,
+    b_bias2: torch.Tensor | None = None,
+    expert_map: torch.Tensor | None = None,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant=False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
     apply_router_weights_on_input: bool = False,
 ) -> torch.Tensor:
     assert (
@@ -1261,10 +1261,10 @@ def torch_moe(
     w2: torch.Tensor,
     score: torch.Tensor,
     topk: int,
-    b_bias1: Optional[torch.Tensor] = None,
-    b_bias2: Optional[torch.Tensor] = None,
+    b_bias1: torch.Tensor | None = None,
+    b_bias2: torch.Tensor | None = None,
     global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
     score = torch.softmax(score, dim=-1, dtype=torch.float32)
     topk_weight, topk_ids = torch.topk(score, topk)
@@ -1298,15 +1298,13 @@ def torch_moe_single(a, w, score, topk):
 # A special version of op check that has a restricted default set of test_utils
 # and a patched version of allclose that supports fp8 types.
 def opcheck(
-    op: Union[
-        torch._ops.OpOverload,
-        torch._ops.OpOverloadPacket,
-        torch._library.custom_ops.CustomOpDef,
-    ],
+    op: torch._ops.OpOverload
+    | torch._ops.OpOverloadPacket
+    | torch._library.custom_ops.CustomOpDef,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: dict[str, Any] | None = None,
     *,
-    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+    test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS,
     raise_exception: bool = True,
     cond: bool = True,
 ) -> dict[str, str]:
@@ -1338,7 +1336,7 @@ def baseline_scaled_mm(
     scale_a: torch.Tensor,
     scale_b: torch.Tensor,
     out_dtype: type[torch.dtype],
-    bias: Optional[torch.Tensor] = None,
+    bias: torch.Tensor | None = None,
 ) -> torch.Tensor:
     # We treat N-dimensional group scaling as extended numpy-style broadcasting
     # in numpy simply stretches dimensions with an extent of 1 to match
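A quick runtime sanity check (a minimal sketch, not part of the commit; assumes Python 3.10+) showing why the rewrite is behavior-preserving: both spellings denote the same union at runtime.

    import types
    import typing

    # PEP 604 unions compare equal to their typing-module equivalents,
    # so `list[int] | None` means exactly Optional[list[int]].
    assert (int | None) == typing.Optional[int]
    assert (int | str) == typing.Union[int, str]

    # The new spelling produces a types.UnionType instance and can be
    # used directly in isinstance checks.
    assert isinstance(int | str, types.UnionType)
    assert isinstance("abc", int | str)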