Update `Optional[x]` to `x | None` and `Union[x, y]` to `x | y` (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,7 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details.
 Run `pytest tests/quantization/test_register_quantization_config.py`.
 """
 
-from typing import Any, Optional
+from typing import Any
 
 import pytest
 import torch
@@ -37,10 +37,10 @@ class FakeQuantLinearMethod(UnquantizedLinearMethod):
 
     def apply(
         self,
-        layer: "torch.nn.Module",
-        x: "torch.Tensor",
-        bias: Optional["torch.Tensor"] = None,
-    ) -> "torch.Tensor":
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
         """Perform fake quantization before the linear layer."""
 
         # Calculate the scales dynamically
@@ -72,7 +72,7 @@ class CustomQuantConfig(QuantizationConfig):
         """Name of the quantization method."""
         return "custom_quant"
 
-    def get_supported_act_dtypes(self) -> list["torch.dtype"]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         """List of supported activation dtypes."""
         return [torch.float16, torch.bfloat16]
 
@@ -92,8 +92,8 @@ class CustomQuantConfig(QuantizationConfig):
         return CustomQuantConfig(num_bits=config.get("num_bits", 8))
 
     def get_quant_method(
-        self, layer: "torch.nn.Module", prefix: str
-    ) -> Optional["FakeQuantLinearMethod"]:
+        self, layer: torch.nn.Module, prefix: str
+    ) -> FakeQuantLinearMethod | None:
         """Get the quantize method to use for the quantized layer."""
         if isinstance(layer, LinearBase):
             return FakeQuantLinearMethod(num_bits=self.num_bits)
Reference in New Issue
Block a user