Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
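This commit migrates the file's type hints to PEP 604 union syntax. A minimal sketch of the pattern being applied (the function names here are illustrative, not from the patch); note that `X | Y` in an annotation is evaluated at definition time, so it requires Python 3.10+ unless `from __future__ import annotations` defers evaluation:

    from typing import Optional, Union  # old spellings, shown for contrast

    def clamp_old(x: Union[int, float], ub: Optional[float] = None) -> float:
        return float(x) if ub is None else min(float(x), ub)

    def clamp_new(x: int | float, ub: float | None = None) -> float:
        return float(x) if ub is None else min(float(x), ub)

    assert clamp_old(3, 2.5) == clamp_new(3, 2.5) == 2.5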
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
 
 import torch
 
@@ -15,13 +14,13 @@ ROCM_FP8FNUZ_MAX = 224.0
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
-def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
+def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
     return torch.as_tensor(x, dtype=torch.float32, device="cuda")
 
 
 def ref_dynamic_per_token_quant(
-    x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None
-) -> tuple[torch.tensor, torch.tensor]:
+    x: torch.Tensor, quant_dtype: torch.dtype, scale_ub: torch.Tensor | None = None
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
         assert quant_dtype == FP8_DTYPE
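Beyond the union syntax, the hunk above also corrects `torch.tensor` (a factory function, not a type) to `torch.Tensor` (the class) in the annotations. A quick standalone illustration of the distinction:

    import torch

    t = torch.tensor([1.0, 2.0])        # torch.tensor builds a tensor from data
    assert isinstance(t, torch.Tensor)  # torch.Tensor is the class, valid as a type hint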
@@ -76,8 +75,8 @@ def ref_dynamic_per_token_quant(
 # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
 # kernel
 def ref_dynamic_per_tensor_fp8_quant(
-    x: torch.tensor,
-) -> tuple[torch.tensor, torch.tensor]:
+    x: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
     fp8_traits = torch.finfo(FP8_DTYPE)
     fp8_traits_max = (
         ROCM_FP8FNUZ_MAX
@@ -250,10 +249,10 @@ def per_block_cast_to_int8(
 
 def dequant(
     t: torch.Tensor,
-    scale: Optional[torch.Tensor],
-    block_shape: Optional[list[int]],
+    scale: torch.Tensor | None,
+    block_shape: list[int] | None,
     per_act_token_quant: bool,
-    out_dtype: Optional[torch.dtype] = torch.float32,
+    out_dtype: torch.dtype | None = torch.float32,
 ) -> torch.Tensor:
     if scale is not None:
         f32 = torch.float32
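The `out_dtype: torch.dtype | None = torch.float32` parameter above is a useful reminder that `| None` describes the type, not the default: the argument defaults to a concrete dtype but may be explicitly set to None. A hypothetical helper (not from the patch) with the same shape:

    import torch

    def maybe_cast(t: torch.Tensor, out_dtype: torch.dtype | None = torch.float32) -> torch.Tensor:
        # None means "leave the dtype alone"; the default still casts to float32.
        return t if out_dtype is None else t.to(out_dtype)

    x = torch.ones(2, dtype=torch.float16)
    assert maybe_cast(x).dtype == torch.float32        # default applies
    assert maybe_cast(x, None).dtype == torch.float16  # None disables the cast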
@@ -267,10 +266,10 @@ def dequant(
 
 def batched_dequant(
     t: torch.Tensor,
-    scale: Optional[torch.Tensor],
-    block_shape: Optional[list[int]],
+    scale: torch.Tensor | None,
+    block_shape: list[int] | None,
     per_act_token_quant: bool,
-    out_dtype: Optional[torch.dtype] = torch.float32,
+    out_dtype: torch.dtype | None = torch.float32,
 ) -> torch.Tensor:
     if scale is not None:
         assert t.shape[0] == scale.shape[0]
@@ -289,9 +288,9 @@ def native_batched_masked_quant_matmul(
     B: torch.Tensor,
     C: torch.Tensor,
     num_expert_tokens: torch.Tensor,
-    A_scale: Optional[torch.Tensor] = None,
-    B_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[list[int]] = None,
+    A_scale: torch.Tensor | None = None,
+    B_scale: torch.Tensor | None = None,
+    block_shape: list[int] | None = None,
     per_act_token_quant: bool = False,
 ) -> torch.Tensor:
     num_expert_tokens_cpu = num_expert_tokens.clone()
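As background on the new syntax (not part of the patch): on Python 3.10+, `X | None` evaluates to a real `types.UnionType` object at runtime, which also works directly with isinstance checks:

    import types

    u = int | None
    assert isinstance(u, types.UnionType)
    assert isinstance(3, u) and isinstance(None, u)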