Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -34,9 +33,9 @@ USE_FP32_REDUCE_DEFAULT = True
|
||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
||||
def query_marlin_supported_quant_types(
|
||||
has_zp: Optional[bool] = None,
|
||||
has_zp: bool | None = None,
|
||||
include_fp_type: bool = True,
|
||||
device_capability: Optional[int] = None,
|
||||
device_capability: int | None = None,
|
||||
):
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
@@ -72,10 +71,10 @@ def query_marlin_supported_quant_types(
|
||||
|
||||
def _check_marlin_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
group_size: int | None,
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None,
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
device_capability: int | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (
|
||||
@@ -109,7 +108,7 @@ def check_marlin_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False,
|
||||
device_capability: Optional[int] = None,
|
||||
device_capability: int | None = None,
|
||||
) -> bool:
|
||||
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
|
||||
return cond
|
||||
@@ -164,7 +163,7 @@ def check_marlin_supports_shape(
|
||||
input_size_per_partition: int,
|
||||
input_size: int,
|
||||
group_size: int,
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
) -> tuple[bool, str | None]:
|
||||
try:
|
||||
verify_marlin_supports_shape(
|
||||
output_size_per_partition, input_size_per_partition, input_size, group_size
|
||||
@@ -445,7 +444,7 @@ def apply_gptq_marlin_linear(
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
is_k_full: bool,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
@@ -494,7 +493,7 @@ def apply_awq_marlin_linear(
|
||||
quant_type: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
|
||||
Reference in New Issue
Block a user