Update `Optional[x]` to `x | None` and `Union[x, y]` to `x | y` (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,8 +5,8 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -39,7 +39,7 @@ from vllm.utils.deep_gemm import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.dtype
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
@@ -54,7 +54,7 @@ def cutlass_scaled_mm(
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
is_hopper: Optional[bool] = None,
|
||||
is_hopper: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
if is_hopper is None:
|
||||
is_hopper = current_platform.is_device_capability(90)
|
||||
@@ -279,8 +279,8 @@ class W8A8BlockFp8LinearOp:
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
@@ -394,7 +394,7 @@ class W8A8BlockFp8LinearOp:
|
||||
],
|
||||
torch.Tensor,
|
||||
],
|
||||
Optional[QuantFP8],
|
||||
QuantFP8 | None,
|
||||
]:
|
||||
if use_cutlass:
|
||||
return self._run_cutlass, (
|
||||
@@ -418,7 +418,7 @@ class W8A8BlockFp8LinearOp:
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
x: torch.Tensor, dtype: Optional[torch.dtype] = None
|
||||
x: torch.Tensor, dtype: torch.dtype | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to float8 values "
|
||||
"with tensor-wise quantization."""
|
||||
@@ -568,10 +568,10 @@ def per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
column_major_scales: bool = False,
|
||||
out_q: Optional[torch.Tensor] = None,
|
||||
use_ue8m0: Optional[bool] = None,
|
||||
out_q: torch.Tensor | None = None,
|
||||
use_ue8m0: bool | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
@@ -754,7 +754,7 @@ def _w8a8_triton_block_scaled_mm(
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_fp8_configs(
|
||||
N: int, K: int, block_n: int, block_k: int
|
||||
) -> Optional[dict[int, Any]]:
|
||||
) -> dict[int, Any] | None:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
The return value will be a dictionary that maps an irregular grid of
|
||||
@@ -1012,7 +1012,7 @@ def validate_fp8_block_shape(
|
||||
def create_fp8_weight_parameter(
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
weight_loader: Optional[Callable],
|
||||
weight_loader: Callable | None,
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create FP8 weight parameter."""
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
@@ -1033,8 +1033,8 @@ def create_fp8_scale_parameter(
|
||||
parameter_type: torch.nn.Parameter,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
block_size: Optional[list[int]],
|
||||
weight_loader: Optional[Callable],
|
||||
block_size: list[int] | None,
|
||||
weight_loader: Callable | None,
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create scale parameter based on quantization strategy."""
|
||||
if parameter_type == ChannelQuantScaleParameter:
|
||||
@@ -1070,7 +1070,7 @@ def create_fp8_scale_parameter(
|
||||
|
||||
|
||||
def create_fp8_input_scale(
|
||||
output_partition_sizes: list[int], weight_loader: Optional[Callable]
|
||||
output_partition_sizes: list[int], weight_loader: Callable | None
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create input scale parameter for static activation quantization."""
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
@@ -1087,8 +1087,8 @@ def process_fp8_weight_tensor_strategy(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
logical_widths: list[int],
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Process weights for tensor-wise quantization strategy."""
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
@@ -1114,8 +1114,8 @@ def process_fp8_weight_tensor_strategy(
|
||||
def process_fp8_weight_channel_strategy(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Process weights for channel-wise quantization strategy."""
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
|
||||
Reference in New Issue
Block a user