Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
@@ -15,13 +14,13 @@ ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
def ref_dynamic_per_token_quant(
x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None
) -> tuple[torch.tensor, torch.tensor]:
x: torch.Tensor, quant_dtype: torch.dtype, scale_ub: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
@@ -76,8 +75,8 @@ def ref_dynamic_per_token_quant(
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(
x: torch.tensor,
) -> tuple[torch.tensor, torch.tensor]:
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
@@ -250,10 +249,10 @@ def per_block_cast_to_int8(
def dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
scale: torch.Tensor | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
out_dtype: torch.dtype | None = torch.float32,
) -> torch.Tensor:
if scale is not None:
f32 = torch.float32
@@ -267,10 +266,10 @@ def dequant(
def batched_dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
scale: torch.Tensor | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
out_dtype: torch.dtype | None = torch.float32,
) -> torch.Tensor:
if scale is not None:
assert t.shape[0] == scale.shape[0]
@@ -289,9 +288,9 @@ def native_batched_masked_quant_matmul(
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor] = None,
B_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
A_scale: torch.Tensor | None = None,
B_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
per_act_token_quant: bool = False,
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()