Update Optional[x] to x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -3,7 +3,6 @@
"""Custom activation functions."""
import math
from typing import Optional
import torch
import torch.nn as nn
@@ -486,7 +485,7 @@ class ScaledActivation(nn.Module):
act_module: nn.Module,
intermediate_size: int,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
params_dtype: torch.dtype | None = None,
):
super().__init__()
self.act = act_module

View File

@@ -4,7 +4,7 @@ import contextlib
import os
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Union
from typing import Any
import torch
@@ -138,7 +138,7 @@ def matmul_kernel_persistent(
def matmul_persistent(
a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
@@ -375,7 +375,7 @@ def mean_dim(
input: torch.Tensor,
dim: int,
keepdim: bool = False,
dtype: Union[torch.dtype, None] = None,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
"""
Triton implementation of torch.mean with single dimension reduction.
@@ -475,9 +475,7 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
return log_softmax(input, dim=dim)
def mean_batch_invariant(
input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None
):
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
result = input.to(torch.float32)

View File

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional
import torch
from einops import rearrange
@@ -32,7 +31,7 @@ def chunk_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
cu_seqlens: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
@@ -86,7 +85,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
cu_seqlens: torch.LongTensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
@@ -119,7 +118,7 @@ def chunk_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):

View File

@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
@@ -257,12 +256,12 @@ def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
g: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, u.shape[-1]
H = u.shape[-2]

View File

@@ -9,7 +9,6 @@
# ruff: noqa: E501
from typing import Optional
import torch
@@ -144,9 +143,9 @@ def chunk_fwd_o(
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]

View File

@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
@@ -104,8 +103,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
g_cumsum: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:

View File

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional
import torch
@@ -163,9 +162,9 @@ def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T = g.shape
@@ -200,9 +199,9 @@ def chunk_local_cumsum_vector(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T, S = g.shape
@@ -248,9 +247,9 @@ def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
output_dtype: torch.dtype | None = torch.float,
**kwargs,
) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:

View File

@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
@@ -169,9 +168,9 @@ def fused_recurrent_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
@@ -248,9 +247,9 @@ class FusedRecurrentFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = fused_recurrent_gated_delta_rule_fwd(
@@ -280,9 +279,9 @@ def fused_recurrent_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""

View File

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os
from typing import Optional
import torch
@@ -90,7 +89,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
def l2norm_fwd(
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None
):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])

View File

@@ -14,7 +14,6 @@
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from functools import lru_cache
from typing import Optional
import torch
import torch.nn as nn
@@ -324,10 +323,10 @@ class LayerNormGated(nn.Module):
self,
hidden_size,
eps: float = 1e-5,
group_size: Optional[int] = None,
group_size: int | None = None,
norm_before_gate: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -364,10 +363,10 @@ class RMSNormGated(nn.Module):
self,
hidden_size,
eps: float = 1e-5,
group_size: Optional[int] = None,
group_size: int | None = None,
norm_before_gate: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).

View File

@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
@@ -407,7 +406,7 @@ def merge_16x16_to_64x64_inverse_kernel(
@input_guard
def solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
cu_seqlens: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""

View File

@@ -11,8 +11,9 @@ import contextlib
import functools
import logging
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, Callable, Literal, Optional
from typing import Any, Literal
import torch
@@ -43,7 +44,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
A wrapped version of the input function with single-entry caching.
"""
cache_entries: tuple[Optional[tuple], Optional[dict], Any] = []
cache_entries: tuple[tuple | None, dict | None, Any] = []
cache_size = 4
@functools.wraps(fn)

View File

@@ -8,7 +8,6 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional
import torch
@@ -123,7 +122,7 @@ def recompute_w_u_fwd(
beta: torch.Tensor,
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
cu_seqlens: torch.LongTensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any, Optional
from typing import Any
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import (
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
_config: dict[str, Any] | None = None
@contextmanager
@@ -31,7 +31,7 @@ def override_config(config):
_config = old_config
def get_config() -> Optional[dict[str, Any]]:
def get_config() -> dict[str, Any] | None:
return _config

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -259,7 +258,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
@@ -282,12 +281,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert expert_tokens_meta is not None

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -110,7 +109,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
expert_tokens_metadata: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -148,12 +147,12 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
experts = (

View File

@@ -34,8 +34,8 @@ def _get_config_dtype_str(
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: Optional[str] = None,
) -> Optional[str]:
ocp_mx_scheme: str | None = None,
) -> str | None:
"""
Return a string used to construct the filename that contains the
tuning info for a particular quantization scheme. See
@@ -60,16 +60,16 @@ def _get_config_dtype_str(
def _quant_flags_to_group_shape(
quant_dtype: Union[torch.dtype, str, None],
quant_dtype: torch.dtype | str | None,
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
block_shape: list[int] | None,
) -> tuple[GroupShape | None, GroupShape | None]:
"""
Convert MoE quantization flags into more generic GroupShapes.
"""
a_shape: Optional[GroupShape]
w_shape: Optional[GroupShape]
a_shape: GroupShape | None
w_shape: GroupShape | None
if block_shape is not None:
assert not per_act_token_quant
assert not per_out_ch_quant
@@ -100,7 +100,7 @@ class FusedMoEQuantDesc:
# The quantized type of this parameters. None means unquantized or
# already quantized.
# TODO (bnell): use scalar_type instead of Union.
dtype: Union[torch.dtype, str, None] = None
dtype: torch.dtype | str | None = None
# A field that describes the quantization group shape, from quant_utils.py.
# * (-1, -1) for per-tensor quantization
@@ -109,7 +109,7 @@ class FusedMoEQuantDesc:
# * (128, 128) for 128x128 deepseek style block quantization
# * (1, 128) for deepseek style activation quantization
# (i.e. per-token-per-group)
shape: Optional[GroupShape] = None
shape: GroupShape | None = None
# Quantization scales.
# TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
@@ -117,13 +117,13 @@ class FusedMoEQuantDesc:
# Quantization alphas or gscales, used for nvfp4 types.
# TODO(bnell): put some of these in subclasses
alpha_or_gscale: Optional[torch.Tensor] = None
alpha_or_gscale: torch.Tensor | None = None
# Zero points for int4/int8 types
zp: Optional[torch.Tensor] = None
zp: torch.Tensor | None = None
# Biases for GPT triton MoE
bias: Optional[torch.Tensor] = None
bias: torch.Tensor | None = None
# TODO(bnell): have subclasses for specific moe methods?
@@ -179,7 +179,7 @@ class FusedMoEQuantConfig:
#
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
def quant_dtype(self) -> torch.dtype | str | None:
return self._a1.dtype
@property
@@ -203,7 +203,7 @@ class FusedMoEQuantConfig:
return self._a1.shape == GroupShape.PER_TENSOR
@property
def block_shape(self) -> Optional[list[int]]:
def block_shape(self) -> list[int] | None:
if (
self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR
@@ -218,34 +218,34 @@ class FusedMoEQuantConfig:
return self.block_shape is not None
@property
def a1_scale(self) -> Optional[torch.Tensor]:
def a1_scale(self) -> torch.Tensor | None:
assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor)
return self._a1.scale
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
def a1_gscale(self) -> torch.Tensor | None:
return self._a1.alpha_or_gscale
@property
def a2_scale(self) -> Optional[torch.Tensor]:
def a2_scale(self) -> torch.Tensor | None:
assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor)
return self._a2.scale
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
def a2_gscale(self) -> torch.Tensor | None:
return self._a2.alpha_or_gscale
@property
def w1_scale(self) -> Optional[torch.Tensor]:
def w1_scale(self) -> torch.Tensor | None:
assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor)
return self._w1.scale
@property
def w1_zp(self) -> Optional[torch.Tensor]:
def w1_zp(self) -> torch.Tensor | None:
return self._w1.zp
@property
def w1_bias(self) -> Optional[torch.Tensor]:
def w1_bias(self) -> torch.Tensor | None:
return self._w1.bias
@property
@@ -254,20 +254,20 @@ class FusedMoEQuantConfig:
return self._w1.scale
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
def g1_alphas(self) -> torch.Tensor | None:
return self._w1.alpha_or_gscale
@property
def w2_scale(self) -> Optional[torch.Tensor]:
def w2_scale(self) -> torch.Tensor | None:
assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor)
return self._w2.scale
@property
def w2_zp(self) -> Optional[torch.Tensor]:
def w2_zp(self) -> torch.Tensor | None:
return self._w2.zp
@property
def w2_bias(self) -> Optional[torch.Tensor]:
def w2_bias(self) -> torch.Tensor | None:
return self._w2.bias
@property
@@ -276,7 +276,7 @@ class FusedMoEQuantConfig:
return self._w2.scale
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
def g2_alphas(self) -> torch.Tensor | None:
return self._w2.alpha_or_gscale
@property
@@ -296,7 +296,7 @@ class FusedMoEQuantConfig:
return self._a1.dtype is None and self._w1.dtype == "int4"
@property
def ocp_mx_scheme(self) -> Union[str, None]:
def ocp_mx_scheme(self) -> str | None:
if not hasattr(self, "_ocp_mx_scheme"):
if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
@@ -322,7 +322,7 @@ class FusedMoEQuantConfig:
def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4"
def config_name(self, dtype: torch.dtype) -> Optional[str]:
def config_name(self, dtype: torch.dtype) -> str | None:
"""
Return a string used to construct the filename that contains the
tuning info for a particular quantization scheme. See
@@ -340,7 +340,7 @@ class FusedMoEQuantConfig:
self,
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int]]:
) -> tuple[int, int] | None:
"""
Construct the proper activation scale shape for this
config.
@@ -363,7 +363,7 @@ class FusedMoEQuantConfig:
num_experts: int,
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int, int]]:
) -> tuple[int, int, int] | None:
"""
Construct the proper activation batched scale shape for this
config, e.g. (num experts, *scale_shape).
@@ -377,23 +377,23 @@ class FusedMoEQuantConfig:
@staticmethod
def make(
quant_dtype: Union[torch.dtype, str, None] = None,
quant_dtype: torch.dtype | str | None = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
g1_alphas: Optional[torch.Tensor] = None,
g2_alphas: Optional[torch.Tensor] = None,
a1_gscale: Optional[torch.Tensor] = None,
a2_gscale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
weight_dtype: Union[torch.dtype, str, None] = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
g1_alphas: torch.Tensor | None = None,
g2_alphas: torch.Tensor | None = None,
a1_gscale: torch.Tensor | None = None,
a2_gscale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
@@ -457,11 +457,11 @@ class FusedMoEQuantConfig:
def fp8_w8a8_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for fp8 activations and fp8 weights.
@@ -481,8 +481,8 @@ def fp8_w8a8_moe_quant_config(
def int8_w8a8_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
"""
@@ -503,8 +503,8 @@ def int8_w8a8_moe_quant_config(
def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations and mxfp4 weights.
@@ -521,12 +521,12 @@ def ocp_mx_moe_quant_config(
quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
weight_dtype: Optional[str] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
weight_dtype: str | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp4 activations and mxfp4 weights.
@@ -575,9 +575,9 @@ def nvfp4_moe_quant_config(
def int4_w4a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
block_shape: Optional[list[int]] = None,
w1_zp: torch.Tensor | None,
w2_zp: torch.Tensor | None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit float activations and int4 weights.
@@ -595,9 +595,9 @@ def int4_w4a16_moe_quant_config(
def int8_w8a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
block_shape: Optional[list[int]] = None,
w1_zp: torch.Tensor | None,
w2_zp: torch.Tensor | None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-bit float activations and int8 weights.
@@ -613,8 +613,8 @@ def int8_w8a16_moe_quant_config(
def biased_moe_quant_config(
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations with biases.

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from torch.nn import functional as F
@@ -33,7 +33,7 @@ def grouped_topk(
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -88,12 +88,12 @@ def select_experts(
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if use_grouped_topk:
assert topk_group is not None
@@ -147,14 +147,14 @@ class IPEXFusedMOE:
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
@@ -189,14 +189,14 @@ class SGLFusedMOE:
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
@@ -247,14 +247,14 @@ class CPUFusedMOE:
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUTLASS based Fused MoE kernels."""
from typing import Callable, Optional
from collections.abc import Callable
import torch
@@ -35,23 +35,23 @@ def run_cutlass_moe_fp8(
topk_ids: torch.Tensor,
activation_callable: Callable,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_num_tokens: torch.Tensor | None,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
topk_weights: Optional[torch.Tensor],
topk_weights: torch.Tensor | None,
):
a1q = hidden_states
@@ -249,7 +249,7 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: Optional[torch.dtype],
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
@@ -278,12 +278,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
@@ -331,7 +331,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
@@ -377,7 +377,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
@@ -390,7 +390,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
@@ -435,7 +435,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
@@ -457,7 +457,7 @@ def cutlass_moe_fp8(
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor:
@@ -768,7 +768,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
@@ -793,12 +793,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], # unused
a2_scale: Optional[torch.Tensor], # unused
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, # unused
a2_scale: torch.Tensor | None, # unused
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
@@ -839,7 +839,7 @@ def cutlass_moe_fp4(
n: int,
k: int,
e: int,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
assert expert_map is None, (
@@ -896,7 +896,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
inplace: bool,
activation: str,
apply_router_weight_on_input: bool,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
return N % 128 == 0 and K % 128 == 0

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from tqdm import tqdm
@@ -204,7 +203,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None
block_m = self.block_shape[0]
@@ -228,12 +227,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is not None
@@ -284,7 +283,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, act_out, mm1_out.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
a2q_scale: torch.Tensor | None = None
a2q, a2q_scale = per_token_group_quant_fp8(
act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out
)
@@ -317,9 +316,9 @@ def deep_gemm_moe_fp8(
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
apply_router_weight_on_input=False,
) -> torch.Tensor:
"""

View File

@@ -6,7 +6,6 @@ and updated to fit vllm needs and terminology.
"""
import functools
from typing import Optional
import torch
@@ -39,7 +38,7 @@ def compute_aligned_M(
num_topk: int,
local_num_experts: int,
alignment: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
):
if (expert_tokens_meta is not None) and (
expert_tokens_meta.expert_num_tokens_cpu is not None
@@ -175,7 +174,7 @@ def ep_scatter(
recv_x_scale: torch.Tensor,
recv_topk: torch.Tensor,
num_recv_tokens_per_expert: torch.Tensor,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
expert_start_loc: torch.Tensor,
output_tensor: torch.Tensor,
output_tensor_scale: torch.Tensor,
@@ -305,7 +304,7 @@ def ep_gather(
recv_topk_ids: torch.Tensor,
recv_topk_weight: torch.Tensor,
input_index: torch.Tensor,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
output_tensor: torch.Tensor,
):
num_warps = 2
@@ -346,9 +345,9 @@ def deepgemm_moe_permute(
aq_scale: torch.Tensor,
topk_ids: torch.Tensor,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
aq_out: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
aq_out: torch.Tensor | None = None,
):
assert aq.ndim == 2
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
@@ -415,7 +414,7 @@ def deepgemm_unpermute_and_reduce(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
inv_perm: torch.Tensor,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
output: torch.Tensor,
):
return ep_gather(

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional, Union
from collections.abc import Callable
import deep_ep
import torch
@@ -77,18 +77,18 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int64
def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
def _get_dispatch_config(self) -> deep_ep.Config | None:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_)
def _get_combine_config(self) -> Optional[deep_ep.Config]:
def _get_combine_config(self) -> deep_ep.Config | None:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_combine_config(self.num_dispatchers_)
@@ -96,11 +96,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_dispatch(
self,
tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
token_scales: torch.Tensor | None,
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1_scale: Optional[torch.Tensor],
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None
@@ -175,12 +175,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
expert_topk_ids: Optional[torch.Tensor],
token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
expert_topk_ids: torch.Tensor | None,
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: Optional[torch.Tensor],
a1_scale: Optional[torch.Tensor],
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if event.event is not None:
@@ -249,7 +249,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.ReceiverType:
@@ -294,7 +294,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
@@ -318,7 +318,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
do_async: bool,
) -> Optional[Callable]:
) -> Callable | None:
a2a_idx = dbo_current_ubatch_id()
handle = self.handles[a2a_idx]
assert handle is not None

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional, Union
from collections.abc import Callable
import deep_ep
import torch
@@ -67,7 +67,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handles: list[Optional[tuple]] = [None, None]
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers
def num_dispatchers(self) -> int:
@@ -80,18 +80,18 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
return self.max_tokens_per_rank
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int64
def _do_quant(
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.use_fp8_dispatch:
block_k = (
quant_config.block_shape[1]
@@ -137,7 +137,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]:
@@ -200,9 +200,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _receiver(
self,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
expert_num_tokens: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a1_scale: torch.Tensor | None,
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
@@ -220,7 +220,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -96,7 +95,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
@@ -133,13 +132,13 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
assert activation == "silu", (
"Only activation silu is supported in FlashInferExperts"
@@ -207,7 +206,7 @@ def flashinfer_cutlass_moe_fp4(
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
@@ -242,7 +241,7 @@ def flashinfer_cutlass_moe(
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
tp_rank: int = 0,
tp_size: int = 1,

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -39,10 +38,10 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
@@ -89,7 +88,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
@@ -164,7 +163,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -105,7 +104,7 @@ direct_register_custom_op(
def flashinfer_fused_moe_per_tensor_scale_fp8(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
@@ -115,8 +114,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
@@ -163,7 +162,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
@@ -173,8 +172,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -370,8 +368,8 @@ def invoke_moe_batched_triton_kernel(
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
@@ -379,7 +377,7 @@ def invoke_moe_batched_triton_kernel(
use_int4_w4a16: bool,
config: dict[str, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
):
assert not use_int4_w4a16
max_num_tokens = A.size(1)
@@ -500,10 +498,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
@@ -518,7 +516,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
@@ -674,7 +672,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
num_experts = local_num_experts
@@ -701,12 +699,12 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert hidden_states.dim() == 3
@@ -754,15 +752,15 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def batched_moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
num_tokens: int,
E: int,
N: int,
expert_num_tokens: torch.Tensor,
qtype: Optional[torch.dtype],
qtype: torch.dtype | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing():
# Note: this does a bunch of extra work because expert_num_tokens is
# ignored but it does support torch.compile + cudagraphs.
@@ -868,7 +866,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
num_experts = local_num_experts
@@ -888,12 +886,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
# Check constraints.

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""
from typing import Optional
import torch
from typing_extensions import override
@@ -28,31 +26,31 @@ def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
bias1: Optional[torch.Tensor],
bias2: Optional[torch.Tensor],
bias1: torch.Tensor | None,
bias2: torch.Tensor | None,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: Optional[torch.Tensor],
gating_output: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: Optional[str] = "silu",
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
intermediate_cache13: Optional[torch.Tensor] = None,
intermediate_cache2: Optional[torch.Tensor] = None,
activation: str | None = "silu",
expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
g_idx2: torch.Tensor | None = None,
sort_indices1: torch.Tensor | None = None,
sort_indices2: torch.Tensor | None = None,
w1_zeros: torch.Tensor | None = None,
w2_zeros: torch.Tensor | None = None,
workspace: torch.Tensor | None = None,
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True,
output: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
inplace: bool = False,
) -> torch.Tensor:
"""
@@ -249,26 +247,26 @@ def fused_marlin_moe_fake(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: Optional[torch.Tensor],
gating_output: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
intermediate_cache13: Optional[torch.Tensor] = None,
intermediate_cache2: Optional[torch.Tensor] = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
expert_map: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
g_idx2: torch.Tensor | None = None,
sort_indices1: torch.Tensor | None = None,
sort_indices2: torch.Tensor | None = None,
w1_zeros: torch.Tensor | None = None,
w2_zeros: torch.Tensor | None = None,
workspace: torch.Tensor | None = None,
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True,
output: Optional[torch.Tensor] = None,
output: torch.Tensor | None = None,
inplace: bool = False,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -341,7 +339,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined
@@ -374,12 +372,12 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert self.w1_scale is not None

View File

@@ -5,7 +5,8 @@
import functools
import json
import os
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import torch
import torch.nn.functional as F
@@ -539,10 +540,10 @@ def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -555,8 +556,8 @@ def invoke_fused_moe_kernel(
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
B_bias: Optional[torch.Tensor] = None,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
@@ -808,7 +809,7 @@ def zero_experts_compute_triton(
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(
E: int, N: int, dtype: Optional[str], block_shape: Optional[list[int]] = None
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -823,10 +824,10 @@ def get_config_file_name(
def get_moe_configs(
E: int,
N: int,
dtype: Optional[str],
block_n: Optional[int] = None,
block_k: Optional[int] = None,
) -> Optional[dict[int, Any]]:
dtype: str | None,
block_n: int | None = None,
block_k: int | None = None,
) -> dict[int, Any] | None:
"""
Return optimized configurations for the fused MoE kernel.
@@ -965,8 +966,8 @@ def get_default_config(
N: int,
K: int,
topk: int,
dtype: Optional[str],
block_shape: Optional[list[int]] = None,
dtype: str | None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
@@ -1016,9 +1017,9 @@ def try_get_optimal_moe_config(
w1_shape: tuple[int, ...],
w2_shape: tuple[int, ...],
top_k: int,
dtype: Optional[str],
dtype: str | None,
M: int,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config
@@ -1076,7 +1077,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: Optional[torch.dtype] = None,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
@@ -1135,7 +1136,7 @@ def grouped_topk(
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if (
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
@@ -1211,7 +1212,7 @@ def eplb_map_to_physical_and_record(
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: Optional[torch.dtype] = None,
indices_type: torch.dtype | None = None,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
@@ -1326,19 +1327,19 @@ def inplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: Optional[str] = None,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> None:
fused_experts_impl(
hidden_states,
@@ -1381,19 +1382,19 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: Optional[str] = None,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> None:
pass
@@ -1423,19 +1424,19 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: Optional[str] = None,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -1477,19 +1478,19 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: Optional[str] = None,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1534,8 +1535,8 @@ def fused_experts(
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
) -> torch.Tensor:
@@ -1625,8 +1626,8 @@ GELU_NO_MUL: str = activation_without_mul("gelu")
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
ocp_mx_scheme: Optional[str],
) -> Union[None, torch.dtype, str]:
ocp_mx_scheme: str | None,
) -> None | torch.dtype | str:
"""
Get the quantization type based on the quantization strategy flags.
We don't have a quant_config at this point so we need to work backwards.
@@ -1660,19 +1661,19 @@ def fused_experts_impl(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: Optional[str] = None,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
@@ -1964,7 +1965,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))
@@ -1981,12 +1982,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
# Check constraints.
@@ -2074,7 +2075,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
a2q_scale: Optional[torch.Tensor] = None
a2q_scale: torch.Tensor | None = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2,

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -80,10 +79,10 @@ def triton_kernel_moe_forward(
topk: int,
renormalize: bool,
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
quant_config: FusedMoEQuantConfig | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(
gating_output, topk, sm_first=not renormalize
@@ -115,13 +114,13 @@ def triton_kernel_fused_experts(
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
quant_config: Optional[FusedMoEQuantConfig] = None,
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
a1q_scale: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
@@ -261,7 +260,7 @@ class OAITritonExperts(BaseOAITritonExperts):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel
workspace1 = (M, K)
@@ -279,12 +278,12 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
if expert_map is not None:

View File

@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from typing import Callable, Literal, Optional, Union, get_args, overload
from typing import Literal, get_args, overload
import torch
import torch.nn.functional as F
@@ -70,15 +70,15 @@ if current_platform.is_cuda_alike():
)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
FusedMoEPrepareAndFinalize = None # type: ignore
FusedMoEPermuteExpertsUnpermute = object # type: ignore
FusedMoEPrepareAndFinalize = object # type: ignore
def _eplb_map_to_physical_and_record(
topk_ids: torch.Tensor,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: Optional[torch.dtype],
indices_type: torch.dtype | None,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
@@ -110,8 +110,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.fused_experts: FusedMoEModularKernel | None = None
self.topk_indices_dtype = None
@abstractmethod
@@ -139,12 +139,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@staticmethod
def _maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: Optional[FusedMoEQuantConfig],
) -> Optional[FusedMoEPrepareAndFinalize]:
quant_config: FusedMoEQuantConfig | None,
) -> FusedMoEPrepareAndFinalize | None:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO: could allow this now
assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"
@@ -229,7 +229,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return prepare_finalize
def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]:
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
if self.moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(
self.moe, self.moe_quant_config
@@ -280,7 +280,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
raise NotImplementedError
@property
@@ -296,21 +296,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@@ -368,7 +368,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
self.flashinfer_cutlass_moe = None # type: ignore
def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]:
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None
else:
@@ -532,21 +532,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
@@ -578,7 +578,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
@@ -595,21 +595,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -705,21 +705,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
enable_eplb is not False
or expert_load_view is not None
@@ -754,21 +754,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
enable_eplb is not False
or expert_load_view is not None
@@ -795,21 +795,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
@@ -860,7 +860,7 @@ def determine_expert_map(
ep_rank: int,
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
) -> tuple[int, Optional[torch.Tensor]]:
) -> tuple[int, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
@@ -941,7 +941,7 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
def maybe_roundup_hidden_size(
hidden_size: int,
act_dtype: torch.dtype,
quant_config: Optional[QuantizationConfig],
quant_config: QuantizationConfig | None,
moe_parallel_config: FusedMoEParallelConfig,
) -> int:
"""
@@ -1016,30 +1016,30 @@ class FusedMoE(CustomOp):
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
params_dtype: torch.dtype | None = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
num_expert_group: int | None = None,
topk_group: int | None = None,
quant_config: QuantizationConfig | None = None,
tp_size: int | None = None,
ep_size: int | None = None,
dp_size: int | None = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
zero_expert_num: int | None = 0,
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
):
super().__init__()
if params_dtype is None:
@@ -1092,9 +1092,9 @@ class FusedMoE(CustomOp):
self.layer_name = prefix
self.enable_eplb = enable_eplb
self.expert_load_view: Optional[torch.Tensor] = None
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
self.expert_load_view: torch.Tensor | None = None
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
# Determine expert maps
if self.use_ep:
@@ -1128,7 +1128,7 @@ class FusedMoE(CustomOp):
)
expert_placement_strategy = "linear"
self.expert_map: Optional[torch.Tensor]
self.expert_map: torch.Tensor | None
local_num_experts, expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
@@ -1187,12 +1187,12 @@ class FusedMoE(CustomOp):
has_bias=has_bias,
)
self.moe_config = moe
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method: QuantizeMethodBase | None = None
quant_method = (
UnquantizedFusedMoEMethod(moe)
if quant_config is None
@@ -1238,8 +1238,8 @@ class FusedMoE(CustomOp):
self.quant_method.create_weights(layer=self, **moe_quant_params)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
if self.use_dp_chunking:
states_shape: tuple[int, ...]
@@ -1262,7 +1262,7 @@ class FusedMoE(CustomOp):
)
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
def shared_experts(self) -> torch.nn.Module | None:
return None
@property
@@ -1534,7 +1534,7 @@ class FusedMoE(CustomOp):
shard_id: str,
expert_id: int,
return_success: bool = False,
) -> Optional[bool]:
) -> bool | None:
if self.quant_config and self.quant_config.get_name() == "mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
@@ -1851,21 +1851,21 @@ class FusedMoE(CustomOp):
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
e_score_correction_bias: torch.Tensor | None = None,
indices_type: torch.dtype | None = None,
enable_eplb: bool = False,
expert_map: Optional[torch.Tensor] = None,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
global_num_experts: Optional[int] = None,
zero_expert_num: Optional[int] = None,
zero_expert_type: Optional[str] = None,
expert_map: torch.Tensor | None = None,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
global_num_experts: int | None = None,
zero_expert_num: int | None = None,
zero_expert_type: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@@ -2006,7 +2006,7 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
hidden_states = F.pad(
@@ -2047,14 +2047,14 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
@@ -2200,7 +2200,7 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
self.ensure_moe_quant_config()

View File

@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Callable, Optional, Union, final
from typing import final
import torch
@@ -81,7 +82,7 @@ class ExpertTokensMetadata:
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: Optional[torch.Tensor]
expert_num_tokens_cpu: torch.Tensor | None
@staticmethod
def make_from_list(
@@ -104,7 +105,7 @@ class TopKWeightAndReduce(ABC):
@abstractmethod
def apply(
self,
output: Optional[torch.Tensor],
output: torch.Tensor | None,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@@ -132,10 +133,10 @@ class TopKWeightAndReduce(ABC):
#
PrepareResultType = tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor | None,
ExpertTokensMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]
ReceiverType = Callable[[], PrepareResultType]
@@ -155,7 +156,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> PrepareResultType:
@@ -195,10 +196,10 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
) -> tuple[Callable, ReceiverType] | ReceiverType:
"""
Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers.
@@ -270,7 +271,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> Union[tuple[Callable, Callable], Callable]:
) -> tuple[Callable, Callable] | Callable:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output but do not wait for results from other workers.
@@ -314,7 +315,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
"""
The PrepareFinalize All2All implementations generally constrain the
dtype of the topk_ids they support. This function returns the
@@ -324,7 +325,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
"""
Some PrepareFinalize All2All implementations are batched. Meaning,
they can process only as set of tokens at a time. This
@@ -423,11 +424,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
#
@property
def quant_dtype(self) -> Optional[torch.dtype]:
def quant_dtype(self) -> torch.dtype | None:
return self.quant_config.quant_dtype
@property
def block_shape(self) -> Optional[list[int]]:
def block_shape(self) -> list[int] | None:
return self.quant_config.block_shape
@property
@@ -439,51 +440,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return self.quant_config.per_out_ch_quant
@property
def a1_scale(self) -> Optional[torch.Tensor]:
def a1_scale(self) -> torch.Tensor | None:
return self.quant_config.a1_scale
@property
def a2_scale(self) -> Optional[torch.Tensor]:
def a2_scale(self) -> torch.Tensor | None:
return self.quant_config.a2_scale
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
def a1_gscale(self) -> torch.Tensor | None:
return self.quant_config.a1_gscale
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
def a2_gscale(self) -> torch.Tensor | None:
return self.quant_config.a2_gscale
@property
def w1_scale(self) -> Optional[torch.Tensor]:
def w1_scale(self) -> torch.Tensor | None:
return self.quant_config.w1_scale
@property
def w2_scale(self) -> Optional[torch.Tensor]:
def w2_scale(self) -> torch.Tensor | None:
return self.quant_config.w2_scale
@property
def w1_zp(self) -> Optional[torch.Tensor]:
def w1_zp(self) -> torch.Tensor | None:
return self.quant_config.w1_zp
@property
def w2_zp(self) -> Optional[torch.Tensor]:
def w2_zp(self) -> torch.Tensor | None:
return self.quant_config.w2_zp
@property
def w1_bias(self) -> Optional[torch.Tensor]:
def w1_bias(self) -> torch.Tensor | None:
return self.quant_config.w1_bias
@property
def w2_bias(self) -> Optional[torch.Tensor]:
def w2_bias(self) -> torch.Tensor | None:
return self.quant_config.w2_bias
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
def g1_alphas(self) -> torch.Tensor | None:
return self.quant_config.g1_alphas
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
def g2_alphas(self) -> torch.Tensor | None:
return self.quant_config.g2_alphas
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@@ -517,7 +518,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[ExpertTokensMetadata],
expert_tokens_meta: ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
@@ -578,12 +579,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
expert_tokens_meta: ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
) -> None:
"""
@@ -625,8 +626,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def _slice_scales(
scales: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
if scales is not None:
if scales.numel() == 1:
return scales
@@ -688,7 +689,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
shared_experts: torch.nn.Module | None = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
@@ -741,7 +742,7 @@ class FusedMoEModularKernel(torch.nn.Module):
top_k: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[ExpertTokensMetadata],
expert_tokens_meta: ExpertTokensMetadata | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Allocate temporary and output buffers for the fused experts op.
@@ -825,11 +826,11 @@ class FusedMoEModularKernel(torch.nn.Module):
@staticmethod
def _slice_expert_tokens_metadata(
num_chunks: int,
full_expert_tokens_meta: Optional[ExpertTokensMetadata],
full_expert_tokens_meta: ExpertTokensMetadata | None,
chunk_topk_ids: torch.Tensor,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
) -> Optional[ExpertTokensMetadata]:
expert_map: torch.Tensor | None,
) -> ExpertTokensMetadata | None:
if num_chunks == 1 or full_expert_tokens_meta is None:
return full_expert_tokens_meta
@@ -861,12 +862,12 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
torch.Tensor | None,
ExpertTokensMetadata | None,
torch.Tensor,
torch.Tensor,
]:
@@ -945,7 +946,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self,
in_dtype: torch.dtype,
a1q: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a1q_scale: torch.Tensor | None,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
@@ -953,9 +954,9 @@ class FusedMoEModularKernel(torch.nn.Module):
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
expert_tokens_meta: Optional[ExpertTokensMetadata],
expert_tokens_meta: ExpertTokensMetadata | None,
) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
@@ -1042,12 +1043,12 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap.
"""
shared_output: Optional[torch.Tensor] = None
shared_output: torch.Tensor | None = None
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
@@ -1112,9 +1113,9 @@ class FusedMoEModularKernel(torch.nn.Module):
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -13,7 +12,7 @@ def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -13,14 +12,12 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a1q_scale: torch.Tensor | None,
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
block_m: int,
) -> tuple[
torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
]:
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
@@ -33,7 +30,7 @@ def _moe_permute(
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
)
inv_perm: Optional[torch.Tensor] = None
inv_perm: torch.Tensor | None = None
num_tokens = top_k_num * tokens_in_chunk
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
@@ -53,7 +50,7 @@ def _moe_permute(
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
inv_perm: torch.Tensor | None,
topk_weight: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
@@ -73,17 +70,15 @@ def _moe_unpermute_and_reduce(
def moe_permute(
hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a1q_scale: torch.Tensor | None,
topk_ids: torch.Tensor,
n_expert: int,
n_local_expert: int = -1,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
permuted_hidden_states: Optional[torch.Tensor] = None,
) -> tuple[
torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
]:
permuted_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This function expands and permutes activation to gather uncontinuous tokens
for each expert.
@@ -198,7 +193,7 @@ def moe_unpermute(
permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
inv_permuted_idx: torch.Tensor,
expert_first_token_offset: Optional[torch.Tensor] = None,
expert_first_token_offset: torch.Tensor | None = None,
) -> None:
"""
This function expands and permutes activation to gathering uncontinuous

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional, Union
from collections.abc import Callable
import pplx_kernels as pplx
import torch
@@ -24,9 +24,9 @@ def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
quant_dtype: torch.dtype | str | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
block_shape: list[int] | None,
):
# All pplx byte sizes must be 16-byte aligned.
align = 16
@@ -82,10 +82,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.uint32
def num_dispatchers(self) -> int:
@@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]:
@@ -148,7 +148,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
)
orig_a_scale_block_shape: Optional[int] = None
orig_a_scale_block_shape: int | None = None
if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1
@@ -184,7 +184,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
device=device,
)
expert_x_scale: Optional[torch.Tensor] = None
expert_x_scale: torch.Tensor | None = None
if a1q.dtype.itemsize == 1:
if quant_config.is_per_act_token:
# (M x 1) -> (E x M x K)
@@ -212,7 +212,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
bound_m: torch.Tensor | None = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
@@ -252,8 +252,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int],
expert_x_scale: torch.Tensor | None,
orig_a_scale_block_shape: int | None,
) -> mk.PrepareResultType:
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
@@ -271,7 +271,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
@@ -302,7 +302,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None
bound_m: torch.Tensor | None = None
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
# num_tokens = output.size(0) # M

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -18,10 +17,10 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
@@ -36,7 +35,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache
from typing import Optional
import torch
@@ -53,13 +52,13 @@ def rocm_aiter_asm_moe_tkw1_impl(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
from aiter import ActivationType
@@ -90,13 +89,13 @@ def rocm_aiter_asm_moe_tkw1_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -206,14 +205,14 @@ def rocm_aiter_fused_moe_impl(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: Optional[torch.Tensor] = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -244,14 +243,14 @@ def rocm_aiter_fused_moe_fake(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: Optional[torch.Tensor] = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -300,7 +299,7 @@ def rocm_aiter_grouped_topk(
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
device = hidden_states.device
@@ -342,8 +341,8 @@ def rocm_aiter_fused_experts(
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
expert_map: Optional[torch.Tensor] = None,
quant_config: Optional[FusedMoEQuantConfig] = None,
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

View File

@@ -10,7 +10,7 @@ like uniform random routing.
"""
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any
import torch
@@ -24,7 +24,7 @@ class RoutingStrategy(ABC):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
indices_type: Optional[torch.dtype] = None,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route tokens to experts.
@@ -89,7 +89,7 @@ class DistributionBasedRouting(RoutingStrategy):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
indices_type: Optional[torch.dtype] = None,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Randomly select experts for each token using the specified distribution.
@@ -269,7 +269,7 @@ class RoutingSimulator:
router_logits: torch.Tensor,
strategy_name: str,
top_k: int,
indices_type: Optional[torch.dtype] = None,
indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Simulate token-to-expert routing using the specified strategy.

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -18,7 +17,7 @@ class SharedFusedMoE(FusedMoE):
def __init__(
self,
shared_experts: Optional[torch.nn.Module],
shared_experts: torch.nn.Module | None,
use_overlapped: bool = True,
**kwargs,
):
@@ -35,7 +34,7 @@ class SharedFusedMoE(FusedMoE):
)
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts if self.use_overlapped else None
def forward(

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -29,7 +28,7 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
def apply(
self,
output: Optional[torch.Tensor],
output: torch.Tensor | None,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@@ -52,7 +51,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
def apply(
self,
output: Optional[torch.Tensor],
output: torch.Tensor | None,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@@ -84,7 +83,7 @@ class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
def apply(
self,
output: Optional[torch.Tensor],
output: torch.Tensor | None,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
@@ -133,7 +132,7 @@ class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
def apply(
self,
output: Optional[torch.Tensor],
output: torch.Tensor | None,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -89,7 +88,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -128,12 +127,12 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
use_deep_gemm = self.allow_deep_gemm and (

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -58,7 +57,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
@@ -100,12 +99,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
topk = topk_ids.size(-1)

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional, Union
import torch
@@ -60,7 +59,7 @@ def _count_expert_num_tokens(
def count_expert_num_tokens(
topk_ids: torch.Tensor, num_local_experts: int, expert_map: Optional[torch.Tensor]
topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None
) -> torch.Tensor:
"""
Count the number to tokens assigned to each expert.
@@ -112,7 +111,7 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
def _nvfp4_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return flashinfer_fp4_quantize(
@@ -122,9 +121,9 @@ def _nvfp4_quantize(
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
per_act_token: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
@@ -148,9 +147,9 @@ def _fp8_quantize(
def _int8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
per_act_token: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform int8 quantization on the inputs. If a block_shape
@@ -175,9 +174,9 @@ def _int8_quantize(
def _mxfp4_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
# TODO: native mxfp4 is currently not integrated in vllm,
@@ -191,9 +190,9 @@ def _mxfp4_quantize(
def _mxfp8_e4m3_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
@@ -203,9 +202,9 @@ def _mxfp8_e4m3_quantize(
def _mxfp6_e3m2_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
@@ -220,9 +219,9 @@ def _mxfp6_e3m2_quantize(
def _mxfp6_e2m3_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
A_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
@@ -237,12 +236,12 @@ def _mxfp6_e2m3_quantize(
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
quant_dtype: Union[None, torch.dtype, str],
A_scale: torch.Tensor | None,
quant_dtype: None | torch.dtype | str,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
is_fp4_scale_swizzled: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None]:
if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
@@ -273,7 +272,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
return m[idx, ...]
def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
def normalize_scales_shape(scales: torch.Tensor | None) -> torch.Tensor | None:
if scales is not None:
if scales.numel() == 1:
scales = scales.view(1, 1)
@@ -283,9 +282,9 @@ def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Ten
def normalize_batched_scales_shape(
scales: Optional[torch.Tensor],
scales: torch.Tensor | None,
num_experts: int,
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
if scales is not None and scales.ndim < 3:
if scales.numel() == 1:
scales = scales.view(1)
@@ -300,9 +299,9 @@ def normalize_batched_scales_shape(
def _validate_scale_shape(
a: torch.Tensor,
a_scale: Optional[torch.Tensor],
a_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
block_shape: list[int] | None,
) -> None:
if a_scale is None:
return

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -159,9 +157,9 @@ class RMSNorm(CustomOp):
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
var_hidden_size: int | None = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
@@ -190,8 +188,8 @@ class RMSNorm(CustomOp):
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
@@ -231,8 +229,8 @@ class RMSNorm(CustomOp):
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@@ -247,8 +245,8 @@ class RMSNorm(CustomOp):
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@@ -263,8 +261,8 @@ class RMSNorm(CustomOp):
def forward_xpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@@ -313,8 +311,8 @@ class GemmaRMSNorm(CustomOp):
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
@@ -337,16 +335,16 @@ class GemmaRMSNorm(CustomOp):
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from einops import rearrange
@@ -529,7 +528,7 @@ def lightning_attention(
v: torch.Tensor,
ed: torch.Tensor,
block_size: int = 256,
kv_history: Optional[torch.Tensor] = None,
kv_history: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply lightning attention algorithm

View File

@@ -3,7 +3,7 @@
import itertools
from abc import abstractmethod
from typing import Any, Optional, Union
from typing import Any
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -187,7 +187,7 @@ class LinearMethodBase(QuantizeMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
@@ -252,7 +252,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
@@ -276,8 +276,8 @@ class LinearBase(CustomOp):
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
@@ -295,7 +295,7 @@ class LinearBase(CustomOp):
self.quant_config = quant_config
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
self.return_bias = return_bias
@@ -333,8 +333,8 @@ class ReplicatedLinear(LinearBase):
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
@@ -409,7 +409,7 @@ class ReplicatedLinear(LinearBase):
def forward(
self,
x: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
@@ -461,9 +461,9 @@ class ColumnParallelLinear(LinearBase):
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
output_sizes: list[int] | None = None,
prefix: str = "",
*,
return_bias: bool = True,
@@ -574,7 +574,7 @@ class ColumnParallelLinear(LinearBase):
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
@@ -633,8 +633,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
@@ -662,7 +662,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None,
loaded_shard_id: int | None = None,
):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
@@ -838,7 +838,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None,
loaded_shard_id: int | None = None,
):
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
@@ -914,11 +914,11 @@ class QKVParallelLinear(ColumnParallelLinear):
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
total_num_kv_heads: int | None = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
@@ -1027,7 +1027,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
loaded_shard_id: str | None = None,
):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
@@ -1071,7 +1071,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
loaded_shard_id: str | None = None,
):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
@@ -1296,9 +1296,9 @@ class RowParallelLinear(LinearBase):
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
params_dtype: torch.dtype | None = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
@@ -1405,7 +1405,7 @@ class RowParallelLinear(LinearBase):
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel:
input_parallel = input_
else:

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
from vllm.distributed import (
@@ -28,10 +26,10 @@ class LogitsProcessor(CustomOp):
def __init__(
self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
org_vocab_size: int | None = None,
scale: float = 1.0,
logits_as_input: bool = False,
soft_cap: Optional[float] = None,
soft_cap: float | None = None,
) -> None:
"""
Args:
@@ -53,8 +51,8 @@ class LogitsProcessor(CustomOp):
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
embedding_bias: torch.Tensor | None = None,
) -> torch.Tensor | None:
if self.logits_as_input:
logits = hidden_states
else:
@@ -88,8 +86,8 @@ class LogitsProcessor(CustomOp):
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
embedding_bias: torch.Tensor | None,
) -> torch.Tensor | None:
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -87,8 +87,8 @@ class MiniMaxText01RMSNormTP(CustomOp):
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
@@ -102,7 +102,7 @@ class MiniMaxText01LinearKernel:
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
layer_idx: int | None = None,
**kwargs,
) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
@@ -154,9 +154,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
max_position: int,
block_size: int,
num_hidden_layer: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
prefix: str = "linear_attn",

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -68,8 +68,8 @@ class MambaMixer(MambaBase, CustomOp):
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -410,7 +410,7 @@ class MambaMixer(MambaBase, CustomOp):
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
def _time_proj_bias(self) -> torch.Tensor | None:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
return None
@@ -423,8 +423,8 @@ class PrefillDecodeSplit(NamedTuple):
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
query_start_loc_p: Optional[torch.Tensor]
has_initial_states_p: Optional[torch.Tensor]
query_start_loc_p: torch.Tensor | None
has_initial_states_p: torch.Tensor | None
def split_batch_to_prefill_and_decode(
@@ -432,7 +432,7 @@ def split_batch_to_prefill_and_decode(
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
query_start_loc: torch.Tensor,
has_initial_states: Optional[torch.Tensor],
has_initial_states: torch.Tensor | None,
num_prefill_tokens: int,
num_decode_tokens: int,
num_prefills: int,

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -138,7 +138,7 @@ class Mixer2RMSNormGated(CustomOp):
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
@@ -244,9 +244,9 @@ class MambaMixer2(MambaBase, CustomOp):
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -474,7 +474,7 @@ class MambaMixer2(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: Optional[torch.Tensor] = None,
mup_vector: torch.Tensor | None = None,
):
pass
@@ -482,7 +482,7 @@ class MambaMixer2(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: Optional[torch.Tensor] = None,
mup_vector: torch.Tensor | None = None,
):
torch.ops.vllm.mamba_mixer2(
hidden_states,
@@ -495,7 +495,7 @@ class MambaMixer2(MambaBase, CustomOp):
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: Optional[torch.Tensor] = None,
mup_vector: torch.Tensor | None = None,
):
forward_context = get_forward_context()
# attn_metadata contains metadata necessary for the mamba2 triton
@@ -904,7 +904,7 @@ def mamba_mixer2(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: Optional[torch.Tensor] = None,
mup_vector: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
@@ -915,7 +915,7 @@ def mamba_mixer2_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: Optional[torch.Tensor] = None,
mup_vector: torch.Tensor | None = None,
) -> None:
return

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import torch
@@ -14,7 +13,7 @@ class MambaStateDtypeCalculator:
@classmethod
def linear_attention_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires testing
@@ -26,7 +25,7 @@ class MambaStateDtypeCalculator:
@classmethod
def mamba1_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
@@ -37,7 +36,7 @@ class MambaStateDtypeCalculator:
@classmethod
def mamba2_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
@@ -48,7 +47,7 @@ class MambaStateDtypeCalculator:
@classmethod
def _mamba_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
@@ -63,7 +62,7 @@ class MambaStateDtypeCalculator:
@classmethod
def short_conv_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
@@ -72,7 +71,7 @@ class MambaStateDtypeCalculator:
@classmethod
def gated_delta_net_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)

View File

@@ -4,7 +4,6 @@
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from typing import Optional, Union
import numpy as np
import torch
@@ -469,17 +468,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Union[torch.Tensor, None],
bias: torch.Tensor | None,
conv_states: torch.Tensor,
query_start_loc: torch.Tensor,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
cache_indices: torch.Tensor | None = None,
has_initial_state: torch.Tensor | None = None,
activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID,
block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
num_computed_tokens: Optional[torch.Tensor] = None,
block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
num_computed_tokens: torch.Tensor | None = None,
block_size_to_align=0,
metadata=None,
validate_data=False,
@@ -1071,15 +1070,15 @@ def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Union[bool, str, None] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
activation: bool | str | None = None,
conv_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
query_start_loc: torch.Tensor | None = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
validate_data=False,
):
"""

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -38,8 +38,8 @@ class ShortConv(MambaBase, CustomOp):
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
model_config: ModelConfig | None = None,
cache_config: CacheConfig | None = None,
prefix: str = "",
):
super().__init__()

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
@@ -19,14 +18,14 @@ class MLAModules:
kv_b_proj: torch.nn.Module
rotary_emb: torch.nn.Module
o_proj: torch.nn.Module
fused_qkv_a_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
indexer: Optional[torch.nn.Module]
fused_qkv_a_proj: torch.nn.Module | None
kv_a_proj_with_mqa: torch.nn.Module | None
q_a_layernorm: torch.nn.Module | None
q_b_proj: torch.nn.Module | None
q_proj: torch.nn.Module | None
indexer: torch.nn.Module | None
is_sparse: bool
topk_indices_buffer: Optional[torch.Tensor]
topk_indices_buffer: torch.Tensor | None
@CustomOp.register("multi_head_latent_attention")
@@ -55,11 +54,11 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
q_lora_rank: int | None,
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()

View File

@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping, Set
from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union
from typing import TypeVar
import torch
import torch.nn as nn
@@ -24,8 +24,8 @@ from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
logger = init_logger(__name__)
PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]],
[torch.Tensor | list[torch.Tensor], PoolingMetadata],
torch.Tensor | list[torch.Tensor],
]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
@@ -90,7 +90,7 @@ class Pooler(nn.Module, ABC):
@staticmethod
def for_classify(
pooler_config: PoolerConfig,
classifier: Optional[ClassifierFn],
classifier: ClassifierFn | None,
):
resolved_config = ResolvedPoolingConfig.from_config(
task="classify",
@@ -118,14 +118,14 @@ class Pooler(nn.Module, ABC):
@abstractmethod
def forward(
self,
hidden_states: Union[list[torch.Tensor], torch.Tensor],
hidden_states: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
return pooling_metadata.prompt_lens
@@ -174,7 +174,7 @@ def get_classification_activation_function(config: PretrainedConfig):
def get_cross_encoder_activation_function(config: PretrainedConfig):
function_name: Optional[str] = None
function_name: str | None = None
if (
hasattr(config, "sentence_transformers")
and "activation_fn" in config.sentence_transformers
@@ -223,14 +223,14 @@ class PoolingMethod(nn.Module, ABC):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
raise NotImplementedError
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
@@ -243,7 +243,7 @@ class CLSPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@@ -259,7 +259,7 @@ class LastPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
return hidden_states[pooling_cursor.last_token_indices_gpu]
@@ -271,7 +271,7 @@ class AllPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with ALL pooling"
)
@@ -290,7 +290,7 @@ class MeanPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@@ -405,7 +405,7 @@ class PoolerHead(nn.Module):
def forward(
self,
pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
return self.activation(pooled_data)
@@ -418,14 +418,14 @@ class EmbeddingPoolerHead(PoolerHead):
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector: Optional[nn.Module] = (
self.projector: nn.Module | None = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
)
self.head_dtype = vllm_config.model_config.head_dtype
def forward(
self,
pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
if isinstance(pooled_data, list):
@@ -480,7 +480,7 @@ class RewardPoolerHead(PoolerHead):
def forward(
self,
pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
if isinstance(pooled_data, list):
@@ -541,7 +541,7 @@ class SimplePooler(Pooler):
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
@@ -560,9 +560,9 @@ class StepPooler(Pooler):
def extract_states(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
) -> list[torch.Tensor] | torch.Tensor:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
@@ -593,7 +593,7 @@ class StepPooler(Pooler):
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
@@ -621,8 +621,8 @@ class ClassifierPooler(Pooler):
def __init__(
self,
pooling: PoolingFn,
classifier: Optional[ClassifierFn],
act_fn: Optional[PoolerActivation] = None,
classifier: ClassifierFn | None,
act_fn: PoolerActivation | None = None,
) -> None:
super().__init__()
@@ -631,7 +631,7 @@ class ClassifierPooler(Pooler):
self.pooling = pooling
self.classifier = classifier
self.act_fn = act_fn or PoolerClassify()
self.logit_bias: Optional[float] = (
self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias
)
self.head_dtype = vllm_config.model_config.head_dtype
@@ -641,7 +641,7 @@ class ClassifierPooler(Pooler):
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
@@ -695,7 +695,7 @@ class DispatchPooler(Pooler):
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any
import torch
@@ -46,8 +46,8 @@ class AutoRoundConfig(QuantizationConfig):
group_size: int,
sym: bool = True,
packing_format: str = "auto_round:auto_gptq",
block_name_to_quantize: Optional[Union[str, list[str]]] = None,
extra_config: Optional[dict[str, Any]] = None,
block_name_to_quantize: str | list[str] | None = None,
extra_config: dict[str, Any] | None = None,
data_type: str = "int",
backend: str = "auto",
) -> None:

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from typing import Any, Union
import torch
@@ -34,7 +34,7 @@ class AWQConfig(QuantizationConfig):
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[list[str]] = None,
modules_to_not_convert: list[str] | None = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
@@ -88,7 +88,7 @@ class AWQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
@@ -227,7 +227,7 @@ class AWQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Optional
import torch
from torch.nn import Parameter
@@ -70,7 +71,7 @@ class AWQMarlinConfig(QuantizationConfig):
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[list[str]],
modules_to_not_convert: list[str] | None,
full_config: dict[str, Any],
) -> None:
super().__init__()
@@ -140,7 +141,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
@@ -360,7 +361,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_awq_marlin_linear(
input=x,
@@ -555,7 +556,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return None
def apply(
@@ -566,21 +567,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:

View File

@@ -3,7 +3,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import torch
from torch import nn
@@ -105,7 +105,7 @@ class QuantizationConfig(ABC):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
@@ -135,7 +135,7 @@ class QuantizationConfig(ABC):
@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
) -> QuantizeMethodBase | None:
"""Get the quantize method to use for the quantized layer.
Args:
@@ -147,7 +147,7 @@ class QuantizationConfig(ABC):
"""
raise NotImplementedError
def get_cache_scale(self, name: str) -> Optional[str]:
def get_cache_scale(self, name: str) -> str | None:
return None
def apply_vllm_mapper( # noqa: B027

View File

@@ -45,10 +45,10 @@ class BitBLASConfig(QuantizationConfig):
def __init__(
self,
weight_bits: int,
group_size: Optional[int],
desc_act: Optional[bool],
is_sym: Optional[bool],
quant_method: Optional[str],
group_size: int | None,
desc_act: bool | None,
is_sym: bool | None,
quant_method: str | None,
lm_head_quantized: bool,
) -> None:
try:
@@ -160,7 +160,7 @@ class BitBLASConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = hf_quant_cfg.get(
@@ -469,7 +469,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Union
import torch
from packaging import version
@@ -41,7 +42,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[list[str]] = None,
llm_int8_skip_modules: list[str] | None = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
@@ -138,7 +139,7 @@ class BitsAndBytesConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
) -> Union["LinearMethodBase", "BitsAndBytesMoEMethod"] | None:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
return UnquantizedLinearMethod()
@@ -268,7 +269,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.quant_config.load_in_8bit:
return self._apply_8bit_weight(layer, x, bias)
@@ -279,7 +280,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# only load the bitsandbytes module when needed
from bitsandbytes import MatmulLtState, matmul
@@ -359,7 +360,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
original_type = x.dtype
original_shape = x.shape
@@ -489,7 +490,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return None
def apply(
@@ -500,21 +501,21 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None

View File

@@ -71,7 +71,7 @@ logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, QuantizationArgs] | None]
class CompressedTensorsConfig(QuantizationConfig):
@@ -82,9 +82,9 @@ class CompressedTensorsConfig(QuantizationConfig):
quant_format: str,
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[dict[str, Any]] = None,
kv_cache_scheme: dict[str, Any] | None = None,
config: dict[str, Any] | None = None,
transform_config: dict[str, Any] | None = None,
):
super().__init__()
self.ignore = ignore
@@ -524,7 +524,7 @@ class CompressedTensorsConfig(QuantizationConfig):
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
format: Optional[str] = None,
format: str | None = None,
) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
@@ -631,7 +631,7 @@ class CompressedTensorsConfig(QuantizationConfig):
raise NotImplementedError("No compressed-tensors compatible scheme was found.")
def get_scheme(
self, layer: torch.nn.Module, layer_name: Optional[str] = None
self, layer: torch.nn.Module, layer_name: str | None = None
) -> Optional["CompressedTensorsScheme"]:
"""
compressed-tensors supports non uniform in the following way:
@@ -674,7 +674,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_targets = self.sparsity_scheme_map.keys() - set(
self.sparsity_ignore_list
)
sparsity_scheme: Optional[SparsityCompressionConfig] = None
sparsity_scheme: SparsityCompressionConfig | None = None
with suppress(ValueError):
matched_target = find_matched_target(
layer_name=layer_name,
@@ -723,7 +723,7 @@ class CompressedTensorsConfig(QuantizationConfig):
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
return scheme
def get_cache_scale(self, name: str) -> Optional[str]:
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
@@ -751,9 +751,9 @@ class CompressedTensorsConfig(QuantizationConfig):
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
input_quant: Optional[QuantizationArgs],
sparsity_scheme: Optional[SparsityCompressionConfig] = None,
weight_quant: QuantizationArgs | None,
input_quant: QuantizationArgs | None,
sparsity_scheme: SparsityCompressionConfig | None = None,
) -> bool:
"""
Check if the layer is supported by the Cutlass 2:4 Kernel
@@ -853,7 +853,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
):
"""
Use the output of create_weights and the CompressedTensorsScheme
@@ -878,7 +878,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM

View File

@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from collections.abc import Callable
from enum import Enum
from typing import Callable, Optional, Union
import torch
from compressed_tensors import CompressionFormat
@@ -372,7 +372,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False
)
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin:
return None
elif not self.allow_flashinfer:
@@ -399,7 +399,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
return None
@@ -420,21 +420,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
@@ -913,7 +913,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale
)
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or self.rocm_aiter_moe_enabled:
return None
else:
@@ -997,7 +997,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
return None
@@ -1022,21 +1022,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
@@ -1280,7 +1280,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return int8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
@@ -1297,21 +1297,21 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:
@@ -1604,7 +1604,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return None
def apply(
@@ -1615,21 +1615,21 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:
@@ -1856,7 +1856,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
assert self.num_bits == 4 or self.num_bits == 8
config_builder = (
int4_w4a16_moe_quant_config
@@ -1880,21 +1880,21 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:
@@ -2092,7 +2092,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def _pack_matrix(
int4_as_int8_2d: torch.Tensor,
scales_2d: torch.Tensor,
bias_1d: Optional[torch.Tensor],
bias_1d: torch.Tensor | None,
in_features: int,
out_features: int,
) -> torch.Tensor:
@@ -2192,7 +2192,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
# CPU dynamic 4-bit MoE path does not use modular kernels or
# fused_experts; quant config is not needed.
return None
@@ -2205,20 +2205,20 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert activation in ("silu", "swigluoai", "swiglu"), (

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
import torch
from compressed_tensors import CompressionFormat, ModelCompressor
@@ -42,9 +43,9 @@ class CompressedTensors24(CompressedTensorsScheme):
def __init__(
self,
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[dict[str, Any]] = None,
weight_quant: QuantizationArgs | None = None,
input_quant: QuantizationArgs | None = None,
model_compression_config: dict[str, Any] | None = None,
):
self.quantized = quantized
self.weight_quant = weight_quant
@@ -247,7 +248,7 @@ class CompressedTensors24(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
import torch
@@ -33,7 +32,7 @@ class CompressedTensorsScheme(ABC):
@abstractmethod
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
):
"""
Run the forward pass for the particular scheme. This is where

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from torch.nn import Parameter
@@ -30,7 +30,7 @@ W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None):
def __init__(self, strategy: str, num_bits: int, group_size: int | None = None):
self.strategy = strategy
self.group_size = group_size
self.tile_size = 16
@@ -143,7 +143,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
layer.workspace = workspace
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
qweight = layer.weight_packed
meta = layer.meta

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
@@ -110,7 +110,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_fp4_marlin_linear(
input=x,

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
@@ -156,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
out = run_nvfp4_emulations(

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from compressed_tensors.quantization import ActivationOrdering
@@ -41,9 +41,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None,
group_size: int | None = None,
symmetric: bool | None = True,
actorder: ActivationOrdering | None = None,
):
self.pack_factor = 32 // num_bits
self.strategy = strategy
@@ -178,6 +178,6 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
@@ -36,7 +36,7 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme):
self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
group_size: int | None = None,
is_static_input_scheme: bool = False,
input_symmetric: bool = True,
):
@@ -148,6 +148,6 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationStrategy
@@ -125,7 +125,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_fp8_marlin_linear(
input=x,

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
@@ -179,7 +179,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationStrategy
@@ -120,6 +120,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from collections.abc import Callable
import torch
from compressed_tensors.quantization import ActivationOrdering
@@ -42,9 +42,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None,
group_size: int | None = None,
symmetric: bool | None = True,
actorder: ActivationOrdering | None = None,
):
self.pack_factor = 32 // num_bits
self.strategy = strategy
@@ -214,6 +214,6 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator
from collections.abc import Callable, Generator
from itertools import accumulate
from typing import Callable, Optional
import torch
from compressed_tensors.transform import (
@@ -38,7 +37,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
def from_schemes(
cls,
quant_method: LinearMethodBase,
quant_scheme: Optional[CompressedTensorsScheme],
quant_scheme: CompressedTensorsScheme | None,
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple],
) -> "CompressedTensorsLinearTransformMethod":
@@ -66,8 +65,8 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
self.input_tfms = input_tfms
self.output_tfms = output_tfms
self.input_transform: Optional[HadamardTransform] = None
self.output_transform: Optional[HadamardTransform] = None
self.input_transform: HadamardTransform | None = None
self.output_transform: HadamardTransform | None = None
def create_weights(
self,
@@ -151,7 +150,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.input_transform is not None:
x = self.input_transform(x)
@@ -194,7 +193,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
def get_linear_transform_schemes(
layer: torch.nn.Module,
layer_name: str,
transform_config: Optional[TransformConfig],
transform_config: TransformConfig | None,
packed_modules_mapping: dict[str, list[str]],
) -> tuple[
dict[int, TransformTuple], dict[int, TransformTuple]
@@ -226,7 +225,7 @@ def get_linear_transform_schemes(
def get_schemes_args(
transform_config: Optional[TransformConfig],
transform_config: TransformConfig | None,
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
if transform_config is None:
return

View File

@@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable
from collections.abc import Callable, Hashable
import torch
from compressed_tensors.transform import (

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -17,7 +16,7 @@ __all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
def is_qutlass_fp4_scheme(
quant_scheme: Optional[CompressedTensorsScheme],
quant_scheme: CompressedTensorsScheme | None,
input_tfms: dict[int, TransformTuple],
) -> bool:
return (
@@ -60,6 +59,6 @@ class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError()

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -145,7 +144,7 @@ def triton_scaled_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
block_size_m: int = 32,
block_size_n: int = 32,
block_size_k: int = 32,

View File

@@ -3,7 +3,6 @@
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Optional
import regex as re
from compressed_tensors import CompressionFormat
@@ -21,7 +20,7 @@ def is_activation_quantization_format(format: str) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
layer_name: str | None,
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
@@ -84,7 +83,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
def find_matched_target(
layer_name: Optional[str],
layer_name: str | None,
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
@@ -134,7 +133,7 @@ def find_matched_target(
def _find_first_match(
value: str, targets: Iterable[str], check_contains: bool = False
) -> Optional[str]:
) -> str | None:
"""
Returns first element of target that matches value either
exactly or as a regex after 're:'. If check_contains is set to True,
@@ -176,7 +175,7 @@ def _match_fused_layer(
layer_name: str,
target_layers: Iterable[str],
fused_mapping: Mapping[str, list[str]],
) -> Optional[str]:
) -> str | None:
"""
Match a fused layer name to its corresponding individual layer in
target_layers. Returns first value in fused_mapping which matches targets
@@ -205,7 +204,7 @@ def _match_fused_layer(
]
# for each unfused component, find a match in targets
unfused_matches: list[Optional[str]] = []
unfused_matches: list[str | None] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):

View File

@@ -140,7 +140,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
weight = layer.weight
y = weight.ds_dequantize()

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Optional
import torch
@@ -129,7 +130,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return int8_w8a16_moe_quant_config(
w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None
)
@@ -142,21 +143,21 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:

View File

@@ -171,7 +171,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.quant_config.use_marlin:
return apply_fp8_marlin_linear(

View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch.nn import Module
@@ -173,8 +174,8 @@ class Fp8Config(QuantizationConfig):
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[list[str]] = None,
weight_block_size: Optional[list[int]] = None,
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
) -> None:
super().__init__()
@@ -298,7 +299,7 @@ class Fp8Config(QuantizationConfig):
return Fp8KVCacheMethod(self)
return None
def get_cache_scale(self, name: str) -> Optional[str]:
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
@@ -530,7 +531,7 @@ class Fp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.use_marlin:
return apply_fp8_marlin_linear(
@@ -584,12 +585,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.fused_experts: Optional[mk.FusedMoEModularKernel] = None # type: ignore
self.fused_experts: mk.FusedMoEModularKernel | None = None # type: ignore
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
@@ -970,7 +971,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv
)
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
or self.use_marlin
@@ -1043,7 +1044,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
return None
@@ -1069,21 +1070,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None

View File

@@ -3,7 +3,7 @@
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
from typing import Any, Optional
from typing import Any
import torch
from torch.nn.parameter import Parameter
@@ -36,7 +36,7 @@ class FPQuantConfig(QuantizationConfig):
forward_dtype: str = "mxfp4",
forward_method: str = "abs_max",
pseudoquantization: bool = False,
modules_to_not_convert: Optional[list[str]] = None,
modules_to_not_convert: list[str] | None = None,
) -> None:
super().__init__()
self.hadamard_group_size = hadamard_group_size
@@ -90,7 +90,7 @@ class FPQuantConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[LinearMethodBase]:
) -> LinearMethodBase | None:
if self.modules_to_not_convert is not None and any(
prefix.endswith(module) for module in self.modules_to_not_convert
):
@@ -233,7 +233,7 @@ class FPQuantLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return quantized_forward(
x,
@@ -381,7 +381,7 @@ def quantized_forward(
weight_scales: torch.Tensor,
weight_global_scale: torch.Tensor,
act_global_scale: torch.Tensor,
bias: Optional[torch.Tensor],
bias: torch.Tensor | None,
forward_hadamard_matrix: torch.Tensor,
forward_method: str,
forward_dtype: str,

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Optional
import gguf
import torch
@@ -35,7 +36,7 @@ logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self, unquantized_modules: Optional[list[str]] = None) -> None:
def __init__(self, unquantized_modules: list[str] | None = None) -> None:
super().__init__()
self.unquantized_modules = unquantized_modules or []
@@ -307,7 +308,7 @@ def _apply_gguf_embedding(
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
if qweight_type in UNQUANTIZED_TYPES:
return torch.embedding(qweight, x)
@@ -330,7 +331,7 @@ def _apply_gguf_embedding_fake(
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
@@ -452,7 +453,7 @@ class GGUFLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
shard_id = layer.qweight.shard_id
@@ -558,7 +559,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return None
def apply(
@@ -569,21 +570,21 @@ class GGUFMoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:

View File

@@ -4,7 +4,7 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
@@ -48,9 +48,9 @@ class GPTQConfig(QuantizationConfig):
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
dynamic: dict[str, dict[str, int | bool]],
autoround_version: str = "",
modules_in_block_to_quantize: Optional[list[str]] = None,
modules_in_block_to_quantize: list[str] | None = None,
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
@@ -148,7 +148,7 @@ class GPTQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None:
if isinstance(layer, FusedMoE):
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
from .moe_wna16 import MoeWNA16Config
@@ -170,7 +170,7 @@ class GPTQConfig(QuantizationConfig):
self.modules_in_block_to_quantize
)
def maybe_update_config(self, model_name: str, revision: Optional[str] = None):
def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
@@ -345,7 +345,7 @@ class GPTQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
reshaped_x = x.reshape(-1, x.shape[-1])

View File

@@ -71,7 +71,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
group_size: int,
desc_act: bool,
is_sym: bool,
quant_method: Optional[str],
quant_method: str | None,
lm_head_quantized: bool,
) -> None:
try:
@@ -180,7 +180,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (
@@ -474,7 +474,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
if bias is not None:

View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from copy import deepcopy
from typing import Any, Callable, Optional, Union
from typing import Any, Optional
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
@@ -103,9 +104,9 @@ class GPTQMarlinConfig(QuantizationConfig):
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
dynamic: dict[str, dict[str, int | bool]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None,
modules_in_block_to_quantize: list[str] | None = None,
) -> None:
super().__init__()
if desc_act and group_size == -1:
@@ -211,7 +212,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
@@ -283,7 +284,7 @@ class GPTQMarlinConfig(QuantizationConfig):
self.modules_in_block_to_quantize
)
def maybe_update_config(self, model_name: str, revision: Optional[str] = None):
def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
@@ -459,7 +460,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
@@ -714,7 +715,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return None
def apply(
@@ -725,21 +726,21 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.fused_experts is None
if enable_eplb:

View File

@@ -114,7 +114,7 @@ class GPTQMarlin24Config(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24"
is_valid_user_quant = (
@@ -287,7 +287,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
qweight = layer.B_24
meta = layer.B_meta

View File

@@ -45,7 +45,7 @@ class HQQMarlinConfig(QuantizationConfig):
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[list[str]] = None,
skip_modules: list[str] | None = None,
) -> None:
super().__init__()
assert group_size == 64, "The only supported HQQ group size is currently 64."
@@ -327,7 +327,7 @@ class HQQMarlinMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
workspace = MarlinWorkspace(
self.output_size_per_partition,

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn.functional as F
@@ -30,9 +29,9 @@ class QuantFP8(CustomOp):
self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None,
num_token_padding: int | None = None,
column_major_scales: bool = False,
use_ue8m0: Optional[bool] = None, # for Torch compile
use_ue8m0: bool | None = None, # for Torch compile
):
"""
:param static: static or dynamic quantization
@@ -64,8 +63,8 @@ class QuantFP8(CustomOp):
def forward_cuda(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
@@ -96,8 +95,8 @@ class QuantFP8(CustomOp):
def forward_native(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
):
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
import torch
from packaging import version
@@ -50,9 +51,9 @@ class IPEXConfig(QuantizationConfig):
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[list[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
modules_to_not_convert: list[str] | None = None,
desc_act: bool | None = None,
lm_head_quantized: bool | None = None,
) -> None:
super().__init__()
self.method = method
@@ -122,7 +123,7 @@ class IPEXConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]:
) -> QuantizationMethods | None:
if not current_platform.is_cpu() and not current_platform.is_xpu():
return None
@@ -206,7 +207,7 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
@@ -275,7 +276,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
@@ -299,7 +300,7 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
weight = layer.weight.data
weight_scale = layer.weight_scale.data
@@ -410,7 +411,7 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
) -> FusedMoEQuantConfig | None:
return None
def apply(
@@ -421,20 +422,20 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
return layer.ipex_fusion(
x,

View File

@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@@ -20,7 +20,7 @@ class MPLinearLayerConfig:
group_size: int
zero_points: bool
has_g_idx: bool
out_type: Optional[torch.dtype] = None
out_type: torch.dtype | None = None
class MPLinearKernel(ABC):
@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(
@@ -39,8 +39,8 @@ class MPLinearKernel(ABC):
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
w_zp_param_name: str | None = None,
w_gidx_param_name: str | None = None,
) -> None:
assert self.can_implement(c)
self.config = c
@@ -62,12 +62,12 @@ class MPLinearKernel(ABC):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
def _transform_param(
self, layer: torch.nn.Module, name: Optional[str], fn: Callable
self, layer: torch.nn.Module, name: str | None, fn: Callable
) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
@@ -83,8 +83,8 @@ class MPLinearKernel(ABC):
) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor], # w_gidx
torch.Tensor | None, # w_zp,
torch.Tensor | None, # w_gidx
]:
return (
getattr(layer, self.w_q_name),

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel,
@@ -48,7 +46,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
def choose_mp_linear_kernel(
config: MPLinearLayerConfig, compute_capability: Optional[int] = None
config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -22,7 +21,7 @@ class AllSparkLinearKernel(MPLinearKernel):
return 80
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"
@@ -87,7 +86,7 @@ class AllSparkLinearKernel(MPLinearKernel):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
gemm_args = self.gemm_args

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from packaging import version
@@ -44,9 +43,9 @@ class BitBLASLinearKernel(MPLinearKernel):
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
bitblas_quant_config: Optional[QuantizationConfig] = None,
w_zp_param_name: str | None = None,
w_gidx_param_name: str | None = None,
bitblas_quant_config: QuantizationConfig | None = None,
):
self.quant_config = bitblas_quant_config
super().__init__(
@@ -57,7 +56,7 @@ class BitBLASLinearKernel(MPLinearKernel):
self,
b_q_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: Optional[torch.Tensor] = None,
qzeros: torch.Tensor | None = None,
):
from bitblas.quantization.utils import general_compress
@@ -82,7 +81,7 @@ class BitBLASLinearKernel(MPLinearKernel):
# qzeros should be de-quantized to int zeros.
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
zeros: Optional[torch.Tensor] = None
zeros: torch.Tensor | None = None
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
if zeros_mode == "original":
zeros = intzeros.to(torch.float16).contiguous()
@@ -113,7 +112,7 @@ class BitBLASLinearKernel(MPLinearKernel):
return 70
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
is_bitblas_installed = True
try:

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec
from typing import Final, Optional
from typing import Final
import torch
@@ -26,7 +26,7 @@ class ConchLinearKernel(MPLinearKernel):
return 80
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
error_msg = (
f"Weight type ({c.weight_type}) not supported by "
@@ -76,7 +76,7 @@ class ConchLinearKernel(MPLinearKernel):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
from conch.ops.quantization.gemm import mixed_precision_gemm

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -26,7 +25,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
return 90
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "CUTLASS only supported on CUDA"
@@ -95,7 +94,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -20,7 +19,7 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
return 1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "Only CPU is supported"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
@@ -95,7 +94,7 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
@@ -25,7 +24,7 @@ class ExllamaLinearKernel(MPLinearKernel):
return 60
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
@@ -137,7 +136,7 @@ class ExllamaLinearKernel(MPLinearKernel):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config

Some files were not shown because too many files have changed in this diff Show More