Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,7 @@ class MPLinearLayerConfig:
|
||||
group_size: int
|
||||
zero_points: bool
|
||||
has_g_idx: bool
|
||||
out_type: Optional[torch.dtype] = None
|
||||
out_type: torch.dtype | None = None
|
||||
|
||||
|
||||
class MPLinearKernel(ABC):
|
||||
@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
@@ -39,8 +39,8 @@ class MPLinearKernel(ABC):
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
w_zp_param_name: str | None = None,
|
||||
w_gidx_param_name: str | None = None,
|
||||
) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
@@ -62,12 +62,12 @@ class MPLinearKernel(ABC):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _transform_param(
|
||||
self, layer: torch.nn.Module, name: Optional[str], fn: Callable
|
||||
self, layer: torch.nn.Module, name: str | None, fn: Callable
|
||||
) -> None:
|
||||
if name is not None and getattr(layer, name, None) is not None:
|
||||
old_param = getattr(layer, name)
|
||||
@@ -83,8 +83,8 @@ class MPLinearKernel(ABC):
|
||||
) -> tuple[
|
||||
torch.Tensor, # w_q
|
||||
torch.Tensor, # w_s
|
||||
Optional[torch.Tensor], # w_zp,
|
||||
Optional[torch.Tensor], # w_gidx
|
||||
torch.Tensor | None, # w_zp,
|
||||
torch.Tensor | None, # w_gidx
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
AllSparkLinearKernel,
|
||||
@@ -48,7 +46,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig, compute_capability: Optional[int] = None
|
||||
config: MPLinearLayerConfig, compute_capability: int | None = None
|
||||
) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,7 +21,7 @@ class AllSparkLinearKernel(MPLinearKernel):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering currently not supported by AllSpark"
|
||||
|
||||
@@ -87,7 +86,7 @@ class AllSparkLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
gemm_args = self.gemm_args
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -44,9 +43,9 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
bitblas_quant_config: Optional[QuantizationConfig] = None,
|
||||
w_zp_param_name: str | None = None,
|
||||
w_gidx_param_name: str | None = None,
|
||||
bitblas_quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
self.quant_config = bitblas_quant_config
|
||||
super().__init__(
|
||||
@@ -57,7 +56,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
b_q_weight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: Optional[torch.Tensor] = None,
|
||||
qzeros: torch.Tensor | None = None,
|
||||
):
|
||||
from bitblas.quantization.utils import general_compress
|
||||
|
||||
@@ -82,7 +81,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
# qzeros should be de-quantized to int zeros.
|
||||
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
|
||||
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
|
||||
zeros: Optional[torch.Tensor] = None
|
||||
zeros: torch.Tensor | None = None
|
||||
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
|
||||
if zeros_mode == "original":
|
||||
zeros = intzeros.to(torch.float16).contiguous()
|
||||
@@ -113,7 +112,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
is_bitblas_installed = True
|
||||
|
||||
try:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from importlib.util import find_spec
|
||||
from typing import Final, Optional
|
||||
from typing import Final
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +26,7 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
|
||||
error_msg = (
|
||||
f"Weight type ({c.weight_type}) not supported by "
|
||||
@@ -76,7 +76,7 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from conch.ops.quantization.gemm import mixed_precision_gemm
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +25,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CUTLASS only supported on CUDA"
|
||||
|
||||
@@ -95,7 +94,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,7 +19,7 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Only CPU is supported"
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
@@ -95,7 +94,7 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,7 +24,7 @@ class ExllamaLinearKernel(MPLinearKernel):
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return (
|
||||
False,
|
||||
@@ -137,7 +136,7 @@ class ExllamaLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -28,7 +27,7 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
# Machete uses CUTLASS, so it can only be compatible with Nvidia
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Machete only supported on CUDA"
|
||||
@@ -129,7 +128,7 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -32,7 +31,7 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
# Marlin uses inline PTX, so it can only be compatible with Nvidia
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Marlin only supported on CUDA"
|
||||
@@ -144,7 +143,7 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -23,7 +22,7 @@ class ScaledMMLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
@@ -52,7 +51,7 @@ class ScaledMMLinearKernel(ABC):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -61,9 +60,9 @@ class ScaledMMLinearKernel(ABC):
|
||||
) -> tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
Optional[torch.Tensor], # input_scale,
|
||||
Optional[torch.Tensor], # input_zp
|
||||
Optional[torch.Tensor], # azp_adj
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
@@ -35,7 +34,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
|
||||
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None
|
||||
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,7 +18,7 @@ def rocm_aiter_gemm_w8a8_impl(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
from aiter import gemm_a8w8_CK
|
||||
@@ -36,7 +35,7 @@ def rocm_aiter_gemm_w8a8_fake(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
m = A.shape[0]
|
||||
@@ -59,7 +58,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
@@ -99,7 +98,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
`AiterScaledMMLinearKernel` implements a fused version of
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -24,7 +23,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "CPUScaledMM requires running on CPU."
|
||||
|
||||
@@ -173,7 +172,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.linear_method(
|
||||
layer,
|
||||
@@ -185,7 +184,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
@@ -207,7 +206,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
return torch.ops._C.int8_scaled_mm_with_quant(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -21,7 +20,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CutlassScaledMM requires running on CUDA."
|
||||
|
||||
@@ -110,7 +109,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,7 +16,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cpu():
|
||||
return (
|
||||
False,
|
||||
@@ -38,6 +37,6 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return super().apply_weights(layer, x, bias)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import cond # noqa: F401
|
||||
@@ -25,7 +24,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "ScaledMMXLA requires running on TPU."
|
||||
|
||||
@@ -77,17 +76,17 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
|
||||
)
|
||||
|
||||
def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||
def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
|
||||
return x
|
||||
|
||||
def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||
def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
|
||||
return x + bias
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user