Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
"""Custom activation functions."""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -486,7 +485,7 @@ class ScaledActivation(nn.Module):
|
||||
act_module: nn.Module,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
|
||||
@@ -4,7 +4,7 @@ import contextlib
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -138,7 +138,7 @@ def matmul_kernel_persistent(
|
||||
|
||||
|
||||
def matmul_persistent(
|
||||
a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None
|
||||
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
|
||||
):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
@@ -375,7 +375,7 @@ def mean_dim(
|
||||
input: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: Union[torch.dtype, None] = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Triton implementation of torch.mean with single dimension reduction.
|
||||
@@ -475,9 +475,7 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
return log_softmax(input, dim=dim)
|
||||
|
||||
|
||||
def mean_batch_invariant(
|
||||
input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None
|
||||
):
|
||||
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
|
||||
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
||||
|
||||
result = input.to(torch.float32)
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@@ -32,7 +31,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
@@ -86,7 +85,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
@@ -119,7 +118,7 @@ def chunk_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -257,12 +256,12 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
g: torch.Tensor | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -144,9 +143,9 @@ def chunk_fwd_o(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
g: torch.Tensor | None = None, # cumsum of log decay
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -104,8 +103,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
g_cumsum: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -163,9 +162,9 @@ def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
@@ -200,9 +199,9 @@ def chunk_local_cumsum_vector(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T, S = g.shape
|
||||
@@ -248,9 +247,9 @@ def chunk_local_cumsum(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if not head_first and g.shape[1] < g.shape[2]:
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -169,9 +168,9 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
@@ -248,9 +247,9 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
o, final_state = fused_recurrent_gated_delta_rule_fwd(
|
||||
@@ -280,9 +279,9 @@ def fused_recurrent_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -90,7 +89,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
||||
|
||||
|
||||
def l2norm_fwd(
|
||||
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
|
||||
x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -324,10 +323,10 @@ class LayerNormGated(nn.Module):
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
group_size: int | None = None,
|
||||
norm_before_gate: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
@@ -364,10 +363,10 @@ class RMSNormGated(nn.Module):
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
group_size: int | None = None,
|
||||
norm_before_gate: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -407,7 +406,7 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
@input_guard
|
||||
def solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@@ -11,8 +11,9 @@ import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,7 +44,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
|
||||
A wrapped version of the input function with single-entry caching.
|
||||
"""
|
||||
|
||||
cache_entries: tuple[Optional[tuple], Optional[dict], Any] = []
|
||||
cache_entries: tuple[tuple | None, dict | None, Any] = []
|
||||
cache_size = 4
|
||||
|
||||
@functools.wraps(fn)
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -123,7 +122,7 @@ def recompute_w_u_fwd(
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
cu_seqlens: torch.LongTensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -31,7 +31,7 @@ def override_config(config):
|
||||
_config = old_config
|
||||
|
||||
|
||||
def get_config() -> Optional[dict[str, Any]]:
|
||||
def get_config() -> dict[str, Any] | None:
|
||||
return _config
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -259,7 +258,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# FIXME (varun): We should be able to dispatch only from the leader
|
||||
# DP ranks in the case of TP > 1. At the moment, all the Ranks
|
||||
@@ -282,12 +281,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert expert_tokens_meta is not None
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -110,7 +109,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_metadata: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
@@ -148,12 +147,12 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
experts = (
|
||||
|
||||
@@ -34,8 +34,8 @@ def _get_config_dtype_str(
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
ocp_mx_scheme: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
tuning info for a particular quantization scheme. See
|
||||
@@ -60,16 +60,16 @@ def _get_config_dtype_str(
|
||||
|
||||
|
||||
def _quant_flags_to_group_shape(
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
quant_dtype: torch.dtype | str | None,
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
|
||||
block_shape: list[int] | None,
|
||||
) -> tuple[GroupShape | None, GroupShape | None]:
|
||||
"""
|
||||
Convert MoE quantization flags into more generic GroupShapes.
|
||||
"""
|
||||
a_shape: Optional[GroupShape]
|
||||
w_shape: Optional[GroupShape]
|
||||
a_shape: GroupShape | None
|
||||
w_shape: GroupShape | None
|
||||
if block_shape is not None:
|
||||
assert not per_act_token_quant
|
||||
assert not per_out_ch_quant
|
||||
@@ -100,7 +100,7 @@ class FusedMoEQuantDesc:
|
||||
# The quantized type of this parameters. None means unquantized or
|
||||
# already quantized.
|
||||
# TODO (bnell): use scalar_type instead of Union.
|
||||
dtype: Union[torch.dtype, str, None] = None
|
||||
dtype: torch.dtype | str | None = None
|
||||
|
||||
# A field that describes the quantization group shape, from quant_utils.py.
|
||||
# * (-1, -1) for per-tensor quantization
|
||||
@@ -109,7 +109,7 @@ class FusedMoEQuantDesc:
|
||||
# * (128, 128) for 128x128 deepseek style block quantization
|
||||
# * (1, 128) for deepseek style activation quantization
|
||||
# (i.e. per-token-per-group)
|
||||
shape: Optional[GroupShape] = None
|
||||
shape: GroupShape | None = None
|
||||
|
||||
# Quantization scales.
|
||||
# TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
|
||||
@@ -117,13 +117,13 @@ class FusedMoEQuantDesc:
|
||||
|
||||
# Quantization alphas or gscales, used for nvfp4 types.
|
||||
# TODO(bnell): put some of these in subclasses
|
||||
alpha_or_gscale: Optional[torch.Tensor] = None
|
||||
alpha_or_gscale: torch.Tensor | None = None
|
||||
|
||||
# Zero points for int4/int8 types
|
||||
zp: Optional[torch.Tensor] = None
|
||||
zp: torch.Tensor | None = None
|
||||
|
||||
# Biases for GPT triton MoE
|
||||
bias: Optional[torch.Tensor] = None
|
||||
bias: torch.Tensor | None = None
|
||||
|
||||
|
||||
# TODO(bnell): have subclasses for specific moe methods?
|
||||
@@ -179,7 +179,7 @@ class FusedMoEQuantConfig:
|
||||
#
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||
def quant_dtype(self) -> torch.dtype | str | None:
|
||||
return self._a1.dtype
|
||||
|
||||
@property
|
||||
@@ -203,7 +203,7 @@ class FusedMoEQuantConfig:
|
||||
return self._a1.shape == GroupShape.PER_TENSOR
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
def block_shape(self) -> list[int] | None:
|
||||
if (
|
||||
self._a1.shape is not None
|
||||
and self._a1.shape != GroupShape.PER_TENSOR
|
||||
@@ -218,34 +218,34 @@ class FusedMoEQuantConfig:
|
||||
return self.block_shape is not None
|
||||
|
||||
@property
|
||||
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||
def a1_scale(self) -> torch.Tensor | None:
|
||||
assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor)
|
||||
return self._a1.scale
|
||||
|
||||
@property
|
||||
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||
def a1_gscale(self) -> torch.Tensor | None:
|
||||
return self._a1.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||
def a2_scale(self) -> torch.Tensor | None:
|
||||
assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor)
|
||||
return self._a2.scale
|
||||
|
||||
@property
|
||||
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||
def a2_gscale(self) -> torch.Tensor | None:
|
||||
return self._a2.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||
def w1_scale(self) -> torch.Tensor | None:
|
||||
assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor)
|
||||
return self._w1.scale
|
||||
|
||||
@property
|
||||
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||
def w1_zp(self) -> torch.Tensor | None:
|
||||
return self._w1.zp
|
||||
|
||||
@property
|
||||
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||
def w1_bias(self) -> torch.Tensor | None:
|
||||
return self._w1.bias
|
||||
|
||||
@property
|
||||
@@ -254,20 +254,20 @@ class FusedMoEQuantConfig:
|
||||
return self._w1.scale
|
||||
|
||||
@property
|
||||
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||
def g1_alphas(self) -> torch.Tensor | None:
|
||||
return self._w1.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||
def w2_scale(self) -> torch.Tensor | None:
|
||||
assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor)
|
||||
return self._w2.scale
|
||||
|
||||
@property
|
||||
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||
def w2_zp(self) -> torch.Tensor | None:
|
||||
return self._w2.zp
|
||||
|
||||
@property
|
||||
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||
def w2_bias(self) -> torch.Tensor | None:
|
||||
return self._w2.bias
|
||||
|
||||
@property
|
||||
@@ -276,7 +276,7 @@ class FusedMoEQuantConfig:
|
||||
return self._w2.scale
|
||||
|
||||
@property
|
||||
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||
def g2_alphas(self) -> torch.Tensor | None:
|
||||
return self._w2.alpha_or_gscale
|
||||
|
||||
@property
|
||||
@@ -296,7 +296,7 @@ class FusedMoEQuantConfig:
|
||||
return self._a1.dtype is None and self._w1.dtype == "int4"
|
||||
|
||||
@property
|
||||
def ocp_mx_scheme(self) -> Union[str, None]:
|
||||
def ocp_mx_scheme(self) -> str | None:
|
||||
if not hasattr(self, "_ocp_mx_scheme"):
|
||||
if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
|
||||
self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
|
||||
@@ -322,7 +322,7 @@ class FusedMoEQuantConfig:
|
||||
def use_nvfp4_w4a4(self) -> bool:
|
||||
return self.quant_dtype == "nvfp4"
|
||||
|
||||
def config_name(self, dtype: torch.dtype) -> Optional[str]:
|
||||
def config_name(self, dtype: torch.dtype) -> str | None:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
tuning info for a particular quantization scheme. See
|
||||
@@ -340,7 +340,7 @@ class FusedMoEQuantConfig:
|
||||
self,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
) -> tuple[int, int] | None:
|
||||
"""
|
||||
Construct the proper activation scale shape for this
|
||||
config.
|
||||
@@ -363,7 +363,7 @@ class FusedMoEQuantConfig:
|
||||
num_experts: int,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int, int]]:
|
||||
) -> tuple[int, int, int] | None:
|
||||
"""
|
||||
Construct the proper activation batched scale shape for this
|
||||
config, e.g. (num experts, *scale_shape).
|
||||
@@ -377,23 +377,23 @@ class FusedMoEQuantConfig:
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
g1_alphas: Optional[torch.Tensor] = None,
|
||||
g2_alphas: Optional[torch.Tensor] = None,
|
||||
a1_gscale: Optional[torch.Tensor] = None,
|
||||
a2_gscale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
weight_dtype: Union[torch.dtype, str, None] = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
g1_alphas: torch.Tensor | None = None,
|
||||
g2_alphas: torch.Tensor | None = None,
|
||||
a1_gscale: torch.Tensor | None = None,
|
||||
a2_gscale: torch.Tensor | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
weight_dtype: torch.dtype | str | None = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
@@ -457,11 +457,11 @@ class FusedMoEQuantConfig:
|
||||
def fp8_w8a8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for fp8 activations and fp8 weights.
|
||||
@@ -481,8 +481,8 @@ def fp8_w8a8_moe_quant_config(
|
||||
def int8_w8a8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
a1_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
@@ -503,8 +503,8 @@ def int8_w8a8_moe_quant_config(
|
||||
def mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations and mxfp4 weights.
|
||||
@@ -521,12 +521,12 @@ def ocp_mx_moe_quant_config(
|
||||
quant_dtype: str,
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
weight_dtype: Optional[str] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
weight_dtype: str | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||
@@ -575,9 +575,9 @@ def nvfp4_moe_quant_config(
|
||||
def int4_w4a16_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_zp: torch.Tensor | None,
|
||||
w2_zp: torch.Tensor | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for 16-bit float activations and int4 weights.
|
||||
@@ -595,9 +595,9 @@ def int4_w4a16_moe_quant_config(
|
||||
def int8_w8a16_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_zp: torch.Tensor | None,
|
||||
w2_zp: torch.Tensor | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for 16-bit float activations and int8 weights.
|
||||
@@ -613,8 +613,8 @@ def int8_w8a16_moe_quant_config(
|
||||
|
||||
|
||||
def biased_moe_quant_config(
|
||||
w1_bias: Optional[torch.Tensor],
|
||||
w2_bias: Optional[torch.Tensor],
|
||||
w1_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations with biases.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
@@ -33,7 +33,7 @@ def grouped_topk(
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -88,12 +88,12 @@ def select_experts(
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
@@ -147,14 +147,14 @@ class IPEXFusedMOE:
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
@@ -189,14 +189,14 @@ class SGLFusedMOE:
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
@@ -247,14 +247,14 @@ class CPUFusedMOE:
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CUTLASS based Fused MoE kernels."""
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
@@ -35,23 +35,23 @@ def run_cutlass_moe_fp8(
|
||||
topk_ids: torch.Tensor,
|
||||
activation_callable: Callable,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
expert_num_tokens: torch.Tensor | None,
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
use_batched_format: bool,
|
||||
topk_weights: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor | None,
|
||||
):
|
||||
a1q = hidden_states
|
||||
|
||||
@@ -249,7 +249,7 @@ def run_cutlass_moe_fp8(
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
@@ -278,12 +278,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
@@ -331,7 +331,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
@@ -377,7 +377,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, max(N // 2, K))
|
||||
@@ -390,7 +390,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
num_dispatchers: int,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
out_dtype: torch.dtype | None,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
@@ -435,7 +435,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
assert num_dp is not None
|
||||
@@ -457,7 +457,7 @@ def cutlass_moe_fp8(
|
||||
c_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
activation: str = "silu",
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
) -> torch.Tensor:
|
||||
@@ -768,7 +768,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1: tuple[int, ...] = ()
|
||||
workspace2: tuple[int, ...] = ()
|
||||
@@ -793,12 +793,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor], # unused
|
||||
a2_scale: Optional[torch.Tensor], # unused
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None, # unused
|
||||
a2_scale: torch.Tensor | None, # unused
|
||||
workspace13: torch.Tensor | None,
|
||||
workspace2: torch.Tensor | None,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
|
||||
@@ -839,7 +839,7 @@ def cutlass_moe_fp4(
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert expert_map is None, (
|
||||
@@ -896,7 +896,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
|
||||
inplace: bool,
|
||||
activation: str,
|
||||
apply_router_weight_on_input: bool,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
) -> bool:
|
||||
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
|
||||
return N % 128 == 0 and K % 128 == 0
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
@@ -204,7 +203,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.block_shape is not None
|
||||
block_m = self.block_shape[0]
|
||||
@@ -228,12 +227,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert a1q_scale is not None
|
||||
@@ -284,7 +283,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
self.activation(activation, act_out, mm1_out.view(-1, N))
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
a2q_scale: torch.Tensor | None = None
|
||||
a2q, a2q_scale = per_token_group_quant_fp8(
|
||||
act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out
|
||||
)
|
||||
@@ -317,9 +316,9 @@ def deep_gemm_moe_fp8(
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input=False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,6 @@ and updated to fit vllm needs and terminology.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -39,7 +38,7 @@ def compute_aligned_M(
|
||||
num_topk: int,
|
||||
local_num_experts: int,
|
||||
alignment: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
):
|
||||
if (expert_tokens_meta is not None) and (
|
||||
expert_tokens_meta.expert_num_tokens_cpu is not None
|
||||
@@ -175,7 +174,7 @@ def ep_scatter(
|
||||
recv_x_scale: torch.Tensor,
|
||||
recv_topk: torch.Tensor,
|
||||
num_recv_tokens_per_expert: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
expert_start_loc: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
output_tensor_scale: torch.Tensor,
|
||||
@@ -305,7 +304,7 @@ def ep_gather(
|
||||
recv_topk_ids: torch.Tensor,
|
||||
recv_topk_weight: torch.Tensor,
|
||||
input_index: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
output_tensor: torch.Tensor,
|
||||
):
|
||||
num_warps = 2
|
||||
@@ -346,9 +345,9 @@ def deepgemm_moe_permute(
|
||||
aq_scale: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
aq_out: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
aq_out: torch.Tensor | None = None,
|
||||
):
|
||||
assert aq.ndim == 2
|
||||
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
|
||||
@@ -415,7 +414,7 @@ def deepgemm_unpermute_and_reduce(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
inv_perm: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
return ep_gather(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
@@ -77,18 +77,18 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.int64
|
||||
|
||||
def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
|
||||
def _get_dispatch_config(self) -> deep_ep.Config | None:
|
||||
if self.num_dispatchers_ not in self.available_rank_configs:
|
||||
return None
|
||||
return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_)
|
||||
|
||||
def _get_combine_config(self) -> Optional[deep_ep.Config]:
|
||||
def _get_combine_config(self) -> deep_ep.Config | None:
|
||||
if self.num_dispatchers_ not in self.available_rank_configs:
|
||||
return None
|
||||
return deep_ep.Buffer.get_combine_config(self.num_dispatchers_)
|
||||
@@ -96,11 +96,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def _do_dispatch(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
token_scales: torch.Tensor | None,
|
||||
rank_topk_ids: torch.Tensor,
|
||||
rank_topk_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a1_scale: torch.Tensor | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> Callable:
|
||||
has_scales = token_scales is not None
|
||||
@@ -175,12 +175,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self,
|
||||
event: deep_ep.EventOverlap,
|
||||
has_scales: bool,
|
||||
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
expert_topk_ids: Optional[torch.Tensor],
|
||||
token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
|
||||
expert_topk_ids: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
expert_num_tokens_per_expert_list: list[int],
|
||||
expert_topk_weights: Optional[torch.Tensor],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
expert_topk_weights: torch.Tensor | None,
|
||||
a1_scale: torch.Tensor | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
if event.event is not None:
|
||||
@@ -249,7 +249,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.ReceiverType:
|
||||
@@ -294,7 +294,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
@@ -318,7 +318,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
do_async: bool,
|
||||
) -> Optional[Callable]:
|
||||
) -> Callable | None:
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
handle = self.handles[a2a_idx]
|
||||
assert handle is not None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
@@ -67,7 +67,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handles: list[Optional[tuple]] = [None, None]
|
||||
self.handles: list[tuple | None] = [None, None]
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
@@ -80,18 +80,18 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return self.max_tokens_per_rank
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.int64
|
||||
|
||||
def _do_quant(
|
||||
self,
|
||||
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.use_fp8_dispatch:
|
||||
block_k = (
|
||||
quant_config.block_shape[1]
|
||||
@@ -137,7 +137,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
@@ -200,9 +200,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
expert_num_tokens: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a1_scale: torch.Tensor | None,
|
||||
a1_dtype: torch.dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
@@ -220,7 +220,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -96,7 +95,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
@@ -133,13 +132,13 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: Optional[bool],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor | None,
|
||||
workspace2: torch.Tensor | None,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool | None,
|
||||
):
|
||||
assert activation == "silu", (
|
||||
"Only activation silu is supported in FlashInferExperts"
|
||||
@@ -207,7 +206,7 @@ def flashinfer_cutlass_moe_fp4(
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
@@ -242,7 +241,7 @@ def flashinfer_cutlass_moe(
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
tp_rank: int = 0,
|
||||
tp_size: int = 1,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -39,10 +38,10 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
@@ -89,7 +88,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
@@ -164,7 +163,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -105,7 +104,7 @@ direct_register_custom_op(
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
@@ -115,8 +114,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
@@ -163,7 +162,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
@@ -173,8 +172,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Fused batched MoE kernel."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
@@ -370,8 +368,8 @@ def invoke_moe_batched_triton_kernel(
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
@@ -379,7 +377,7 @@ def invoke_moe_batched_triton_kernel(
|
||||
use_int4_w4a16: bool,
|
||||
config: dict[str, int],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
):
|
||||
assert not use_int4_w4a16
|
||||
max_num_tokens = A.size(1)
|
||||
@@ -500,10 +498,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return self.max_num_tokens
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
@@ -518,7 +516,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
@@ -674,7 +672,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
@@ -701,12 +699,12 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert hidden_states.dim() == 3
|
||||
@@ -754,15 +752,15 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def batched_moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
num_tokens: int,
|
||||
E: int,
|
||||
N: int,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
qtype: Optional[torch.dtype],
|
||||
qtype: torch.dtype | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing():
|
||||
# Note: this does a bunch of extra work because expert_num_tokens is
|
||||
# ignored but it does support torch.compile + cudagraphs.
|
||||
@@ -868,7 +866,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
@@ -888,12 +886,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Fused MoE utilities for GPTQ."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -28,31 +26,31 @@ def fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
bias1: Optional[torch.Tensor],
|
||||
bias2: Optional[torch.Tensor],
|
||||
bias1: torch.Tensor | None,
|
||||
bias2: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: Optional[torch.Tensor],
|
||||
gating_output: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
activation: Optional[str] = "silu",
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
global_scale1: Optional[torch.Tensor] = None,
|
||||
global_scale2: Optional[torch.Tensor] = None,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
workspace: Optional[torch.Tensor] = None,
|
||||
intermediate_cache13: Optional[torch.Tensor] = None,
|
||||
intermediate_cache2: Optional[torch.Tensor] = None,
|
||||
activation: str | None = "silu",
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
g_idx1: torch.Tensor | None = None,
|
||||
g_idx2: torch.Tensor | None = None,
|
||||
sort_indices1: torch.Tensor | None = None,
|
||||
sort_indices2: torch.Tensor | None = None,
|
||||
w1_zeros: torch.Tensor | None = None,
|
||||
w2_zeros: torch.Tensor | None = None,
|
||||
workspace: torch.Tensor | None = None,
|
||||
intermediate_cache13: torch.Tensor | None = None,
|
||||
intermediate_cache2: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output: torch.Tensor | None = None,
|
||||
inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -249,26 +247,26 @@ def fused_marlin_moe_fake(
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: Optional[torch.Tensor],
|
||||
gating_output: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
global_scale1: Optional[torch.Tensor] = None,
|
||||
global_scale2: Optional[torch.Tensor] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
workspace: Optional[torch.Tensor] = None,
|
||||
intermediate_cache13: Optional[torch.Tensor] = None,
|
||||
intermediate_cache2: Optional[torch.Tensor] = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
g_idx1: torch.Tensor | None = None,
|
||||
g_idx2: torch.Tensor | None = None,
|
||||
sort_indices1: torch.Tensor | None = None,
|
||||
sort_indices2: torch.Tensor | None = None,
|
||||
w1_zeros: torch.Tensor | None = None,
|
||||
w2_zeros: torch.Tensor | None = None,
|
||||
workspace: torch.Tensor | None = None,
|
||||
intermediate_cache13: torch.Tensor | None = None,
|
||||
intermediate_cache2: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output: torch.Tensor | None = None,
|
||||
inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
@@ -341,7 +339,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Modular Kernel provisions output buffer from workspace1. However in
|
||||
# the fused_marlin_moe() function, the final torch.sum(), is defined
|
||||
@@ -374,12 +372,12 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert self.w1_scale is not None
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -539,10 +540,10 @@ def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: Optional[torch.Tensor],
|
||||
topk_weights: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
@@ -555,8 +556,8 @@ def invoke_fused_moe_kernel(
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
B_bias: Optional[torch.Tensor] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
@@ -808,7 +809,7 @@ def zero_experts_compute_triton(
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
|
||||
def get_config_file_name(
|
||||
E: int, N: int, dtype: Optional[str], block_shape: Optional[list[int]] = None
|
||||
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
|
||||
) -> str:
|
||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||
@@ -823,10 +824,10 @@ def get_config_file_name(
|
||||
def get_moe_configs(
|
||||
E: int,
|
||||
N: int,
|
||||
dtype: Optional[str],
|
||||
block_n: Optional[int] = None,
|
||||
block_k: Optional[int] = None,
|
||||
) -> Optional[dict[int, Any]]:
|
||||
dtype: str | None,
|
||||
block_n: int | None = None,
|
||||
block_k: int | None = None,
|
||||
) -> dict[int, Any] | None:
|
||||
"""
|
||||
Return optimized configurations for the fused MoE kernel.
|
||||
|
||||
@@ -965,8 +966,8 @@ def get_default_config(
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
dtype: Optional[str],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
dtype: str | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
@@ -1016,9 +1017,9 @@ def try_get_optimal_moe_config(
|
||||
w1_shape: tuple[int, ...],
|
||||
w2_shape: tuple[int, ...],
|
||||
top_k: int,
|
||||
dtype: Optional[str],
|
||||
dtype: str | None,
|
||||
M: int,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
from vllm.model_executor.layers.fused_moe import get_config
|
||||
|
||||
@@ -1076,7 +1077,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
|
||||
|
||||
@@ -1135,7 +1136,7 @@ def grouped_topk(
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
|
||||
@@ -1211,7 +1212,7 @@ def eplb_map_to_physical_and_record(
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map the logical expert ids to physical expert ids
|
||||
@@ -1326,19 +1327,19 @@ def inplace_fused_experts(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1381,19 +1382,19 @@ def inplace_fused_experts_fake(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -1423,19 +1424,19 @@ def outplace_fused_experts(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1477,19 +1478,19 @@ def outplace_fused_experts_fake(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -1534,8 +1535,8 @@ def fused_experts(
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
allow_cutlass_block_scaled_grouped_gemm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -1625,8 +1626,8 @@ GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
def _get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
ocp_mx_scheme: Optional[str],
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
ocp_mx_scheme: str | None,
|
||||
) -> None | torch.dtype | str:
|
||||
"""
|
||||
Get the quantization type based on the quantization strategy flags.
|
||||
We don't have a quant_config at this point so we need to work backwards.
|
||||
@@ -1660,19 +1661,19 @@ def fused_experts_impl(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
@@ -1964,7 +1965,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M, topk, max(N // 2, K))
|
||||
workspace2 = (M, topk, max(N, K))
|
||||
@@ -1981,12 +1982,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
@@ -2074,7 +2075,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
a2q_scale: torch.Tensor | None = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -80,10 +79,10 @@ def triton_kernel_moe_forward(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
activation: str = "silu",
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
routing_data, gather_idx, scatter_idx = routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
@@ -115,13 +114,13 @@ def triton_kernel_fused_experts(
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
a1q_scale: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
a1q_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
@@ -261,7 +260,7 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspace are allocated inside the kernel
|
||||
workspace1 = (M, K)
|
||||
@@ -279,12 +278,12 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if expert_map is not None:
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Union, get_args, overload
|
||||
from typing import Literal, get_args, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -70,15 +70,15 @@ if current_platform.is_cuda_alike():
|
||||
)
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
FusedMoEPermuteExpertsUnpermute = None # type: ignore
|
||||
FusedMoEPrepareAndFinalize = None # type: ignore
|
||||
FusedMoEPermuteExpertsUnpermute = object # type: ignore
|
||||
FusedMoEPrepareAndFinalize = object # type: ignore
|
||||
|
||||
def _eplb_map_to_physical_and_record(
|
||||
topk_ids: torch.Tensor,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
indices_type: Optional[torch.dtype],
|
||||
indices_type: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
# CPU fallback: no EPLB so just return as is
|
||||
return topk_ids
|
||||
@@ -110,8 +110,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
self.moe = moe
|
||||
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
self.fused_experts: Optional[FusedMoEModularKernel] = None
|
||||
self.moe_quant_config: FusedMoEQuantConfig | None = None
|
||||
self.fused_experts: FusedMoEModularKernel | None = None
|
||||
self.topk_indices_dtype = None
|
||||
|
||||
@abstractmethod
|
||||
@@ -139,12 +139,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
@staticmethod
|
||||
def _maybe_make_prepare_finalize(
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
quant_config: FusedMoEQuantConfig | None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
|
||||
|
||||
# TODO: could allow this now
|
||||
assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"
|
||||
@@ -229,7 +229,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
return prepare_finalize
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
|
||||
if self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config
|
||||
@@ -280,7 +280,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
@abstractmethod
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -296,21 +296,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -368,7 +368,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
)
|
||||
self.flashinfer_cutlass_moe = None # type: ignore
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
@@ -532,21 +532,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
@@ -578,7 +578,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.moe.has_bias:
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
@@ -595,21 +595,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
|
||||
@@ -705,21 +705,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
enable_eplb is not False
|
||||
or expert_load_view is not None
|
||||
@@ -754,21 +754,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
enable_eplb is not False
|
||||
or expert_load_view is not None
|
||||
@@ -795,21 +795,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
@@ -860,7 +860,7 @@ def determine_expert_map(
|
||||
ep_rank: int,
|
||||
global_num_experts: int,
|
||||
expert_placement_strategy: ExpertPlacementStrategy = "linear",
|
||||
) -> tuple[int, Optional[torch.Tensor]]:
|
||||
) -> tuple[int, torch.Tensor | None]:
|
||||
"""
|
||||
Calculates how many experts should be assigned to each rank for EP and
|
||||
creates a mapping from global to local expert index. Experts are
|
||||
@@ -941,7 +941,7 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
||||
def maybe_roundup_hidden_size(
|
||||
hidden_size: int,
|
||||
act_dtype: torch.dtype,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
quant_config: QuantizationConfig | None,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -1016,30 +1016,30 @@ class FusedMoE(CustomOp):
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
reduce_results: bool = False,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
dp_size: Optional[int] = None,
|
||||
num_expert_group: int | None = None,
|
||||
topk_group: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
tp_size: int | None = None,
|
||||
ep_size: int | None = None,
|
||||
dp_size: int | None = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
zero_expert_num: Optional[int] = 0,
|
||||
zero_expert_type: Optional[str] = None,
|
||||
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
|
||||
zero_expert_num: int | None = 0,
|
||||
zero_expert_type: str | None = None,
|
||||
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@@ -1092,9 +1092,9 @@ class FusedMoE(CustomOp):
|
||||
self.layer_name = prefix
|
||||
|
||||
self.enable_eplb = enable_eplb
|
||||
self.expert_load_view: Optional[torch.Tensor] = None
|
||||
self.logical_to_physical_map: Optional[torch.Tensor] = None
|
||||
self.logical_replica_count: Optional[torch.Tensor] = None
|
||||
self.expert_load_view: torch.Tensor | None = None
|
||||
self.logical_to_physical_map: torch.Tensor | None = None
|
||||
self.logical_replica_count: torch.Tensor | None = None
|
||||
|
||||
# Determine expert maps
|
||||
if self.use_ep:
|
||||
@@ -1128,7 +1128,7 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
expert_placement_strategy = "linear"
|
||||
|
||||
self.expert_map: Optional[torch.Tensor]
|
||||
self.expert_map: torch.Tensor | None
|
||||
local_num_experts, expert_map = determine_expert_map(
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
@@ -1187,12 +1187,12 @@ class FusedMoE(CustomOp):
|
||||
has_bias=has_bias,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
self.moe_quant_config: FusedMoEQuantConfig | None = None
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
quant_method: Optional[QuantizeMethodBase] = None
|
||||
quant_method: QuantizeMethodBase | None = None
|
||||
quant_method = (
|
||||
UnquantizedFusedMoEMethod(moe)
|
||||
if quant_config is None
|
||||
@@ -1238,8 +1238,8 @@ class FusedMoE(CustomOp):
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
self.batched_router_logits: torch.Tensor | None = None
|
||||
|
||||
if self.use_dp_chunking:
|
||||
states_shape: tuple[int, ...]
|
||||
@@ -1262,7 +1262,7 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
def shared_experts(self) -> torch.nn.Module | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -1534,7 +1534,7 @@ class FusedMoE(CustomOp):
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
return_success: bool = False,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
if self.quant_config and self.quant_config.get_name() == "mxfp4":
|
||||
# (FIXME) for gpt-oss all experts are combined
|
||||
if "bias" in weight_name:
|
||||
@@ -1851,21 +1851,21 @@ class FusedMoE(CustomOp):
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
enable_eplb: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
global_num_experts: Optional[int] = None,
|
||||
zero_expert_num: Optional[int] = None,
|
||||
zero_expert_type: Optional[str] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
global_num_experts: int | None = None,
|
||||
zero_expert_num: int | None = None,
|
||||
zero_expert_type: str | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Route the input hidden states to the top-k experts based on the
|
||||
@@ -2006,7 +2006,7 @@ class FusedMoE(CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
og_hidden_states = hidden_states.shape[-1]
|
||||
if self.hidden_size != og_hidden_states:
|
||||
hidden_states = F.pad(
|
||||
@@ -2047,14 +2047,14 @@ class FusedMoE(CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward_native(hidden_states, router_logits)
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
|
||||
@@ -2200,7 +2200,7 @@ class FusedMoE(CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.ensure_moe_quant_config()
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Callable, Optional, Union, final
|
||||
from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
@@ -81,7 +82,7 @@ class ExpertTokensMetadata:
|
||||
"""
|
||||
|
||||
expert_num_tokens: torch.Tensor
|
||||
expert_num_tokens_cpu: Optional[torch.Tensor]
|
||||
expert_num_tokens_cpu: torch.Tensor | None
|
||||
|
||||
@staticmethod
|
||||
def make_from_list(
|
||||
@@ -104,7 +105,7 @@ class TopKWeightAndReduce(ABC):
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
output: Optional[torch.Tensor],
|
||||
output: torch.Tensor | None,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
@@ -132,10 +133,10 @@ class TopKWeightAndReduce(ABC):
|
||||
#
|
||||
PrepareResultType = tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
torch.Tensor | None,
|
||||
ExpertTokensMetadata | None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
]
|
||||
|
||||
ReceiverType = Callable[[], PrepareResultType]
|
||||
@@ -155,7 +156,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> PrepareResultType:
|
||||
@@ -195,10 +196,10 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
|
||||
) -> tuple[Callable, ReceiverType] | ReceiverType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
@@ -270,7 +271,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
) -> Union[tuple[Callable, Callable], Callable]:
|
||||
) -> tuple[Callable, Callable] | Callable:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output but do not wait for results from other workers.
|
||||
@@ -314,7 +315,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
@@ -324,7 +325,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only as set of tokens at a time. This
|
||||
@@ -423,11 +424,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
#
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
def quant_dtype(self) -> torch.dtype | None:
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
def block_shape(self) -> list[int] | None:
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
@@ -439,51 +440,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||
def a1_scale(self) -> torch.Tensor | None:
|
||||
return self.quant_config.a1_scale
|
||||
|
||||
@property
|
||||
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||
def a2_scale(self) -> torch.Tensor | None:
|
||||
return self.quant_config.a2_scale
|
||||
|
||||
@property
|
||||
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||
def a1_gscale(self) -> torch.Tensor | None:
|
||||
return self.quant_config.a1_gscale
|
||||
|
||||
@property
|
||||
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||
def a2_gscale(self) -> torch.Tensor | None:
|
||||
return self.quant_config.a2_gscale
|
||||
|
||||
@property
|
||||
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||
def w1_scale(self) -> torch.Tensor | None:
|
||||
return self.quant_config.w1_scale
|
||||
|
||||
@property
|
||||
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||
def w2_scale(self) -> torch.Tensor | None:
|
||||
return self.quant_config.w2_scale
|
||||
|
||||
@property
|
||||
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||
def w1_zp(self) -> torch.Tensor | None:
|
||||
return self.quant_config.w1_zp
|
||||
|
||||
@property
|
||||
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||
def w2_zp(self) -> torch.Tensor | None:
|
||||
return self.quant_config.w2_zp
|
||||
|
||||
@property
|
||||
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||
def w1_bias(self) -> torch.Tensor | None:
|
||||
return self.quant_config.w1_bias
|
||||
|
||||
@property
|
||||
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||
def w2_bias(self) -> torch.Tensor | None:
|
||||
return self.quant_config.w2_bias
|
||||
|
||||
@property
|
||||
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||
def g1_alphas(self) -> torch.Tensor | None:
|
||||
return self.quant_config.g1_alphas
|
||||
|
||||
@property
|
||||
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||
def g2_alphas(self) -> torch.Tensor | None:
|
||||
return self.quant_config.g2_alphas
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@@ -517,7 +518,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
"""
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
@@ -578,12 +579,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -625,8 +626,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
|
||||
def _slice_scales(
|
||||
scales: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
@@ -688,7 +689,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
@@ -741,7 +742,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate temporary and output buffers for the fused experts op.
|
||||
@@ -825,11 +826,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
@staticmethod
|
||||
def _slice_expert_tokens_metadata(
|
||||
num_chunks: int,
|
||||
full_expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
full_expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
chunk_topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
) -> Optional[ExpertTokensMetadata]:
|
||||
expert_map: torch.Tensor | None,
|
||||
) -> ExpertTokensMetadata | None:
|
||||
if num_chunks == 1 or full_expert_tokens_meta is None:
|
||||
return full_expert_tokens_meta
|
||||
|
||||
@@ -861,12 +862,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
torch.Tensor | None,
|
||||
ExpertTokensMetadata | None,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
@@ -945,7 +946,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self,
|
||||
in_dtype: torch.dtype,
|
||||
a1q: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a1q_scale: torch.Tensor | None,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
@@ -953,9 +954,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
) -> torch.Tensor:
|
||||
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
|
||||
a1q, w1, w2, topk_ids
|
||||
@@ -1042,12 +1043,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The _finalize method is a wrapper around self.prepare_finalize.finalize
|
||||
that handles DBO, async and shared expert overlap.
|
||||
"""
|
||||
shared_output: Optional[torch.Tensor] = None
|
||||
shared_output: torch.Tensor | None = None
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
@@ -1112,9 +1113,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
of weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -13,7 +12,7 @@ def moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -13,14 +12,12 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
|
||||
|
||||
def _moe_permute(
|
||||
curr_hidden_states: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a1q_scale: torch.Tensor | None,
|
||||
curr_topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
block_m: int,
|
||||
) -> tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
|
||||
]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
@@ -33,7 +30,7 @@ def _moe_permute(
|
||||
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
|
||||
)
|
||||
|
||||
inv_perm: Optional[torch.Tensor] = None
|
||||
inv_perm: torch.Tensor | None = None
|
||||
|
||||
num_tokens = top_k_num * tokens_in_chunk
|
||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||
@@ -53,7 +50,7 @@ def _moe_permute(
|
||||
def _moe_unpermute_and_reduce(
|
||||
out: torch.Tensor,
|
||||
curr_hidden: torch.Tensor,
|
||||
inv_perm: Optional[torch.Tensor],
|
||||
inv_perm: torch.Tensor | None,
|
||||
topk_weight: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
@@ -73,17 +70,15 @@ def _moe_unpermute_and_reduce(
|
||||
|
||||
def moe_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a1q_scale: torch.Tensor | None,
|
||||
topk_ids: torch.Tensor,
|
||||
n_expert: int,
|
||||
n_local_expert: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
align_block_size: int | None = None,
|
||||
fill_invalid_expert: int = -1,
|
||||
permuted_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
|
||||
]:
|
||||
permuted_hidden_states: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function expands and permutes activation to gather uncontinuous tokens
|
||||
for each expert.
|
||||
@@ -198,7 +193,7 @@ def moe_unpermute(
|
||||
permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
inv_permuted_idx: torch.Tensor,
|
||||
expert_first_token_offset: Optional[torch.Tensor] = None,
|
||||
expert_first_token_offset: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
This function expands and permutes activation to gathering uncontinuous
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import pplx_kernels as pplx
|
||||
import torch
|
||||
@@ -24,9 +24,9 @@ def pplx_hidden_dim_scale_bytes(
|
||||
max_num_tokens: int,
|
||||
hidden_dim: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
quant_dtype: torch.dtype | str | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
block_shape: list[int] | None,
|
||||
):
|
||||
# All pplx byte sizes must be 16-byte aligned.
|
||||
align = 16
|
||||
@@ -82,10 +82,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return self.max_num_tokens
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.uint32
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
@@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
@@ -148,7 +148,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
|
||||
)
|
||||
|
||||
orig_a_scale_block_shape: Optional[int] = None
|
||||
orig_a_scale_block_shape: int | None = None
|
||||
|
||||
if a1q_scale is not None:
|
||||
scalar_scales = a1q_scale.numel() == 1
|
||||
@@ -184,7 +184,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
device=device,
|
||||
)
|
||||
|
||||
expert_x_scale: Optional[torch.Tensor] = None
|
||||
expert_x_scale: torch.Tensor | None = None
|
||||
if a1q.dtype.itemsize == 1:
|
||||
if quant_config.is_per_act_token:
|
||||
# (M x 1) -> (E x M x K)
|
||||
@@ -212,7 +212,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# This argument is optional, defaults to indices.size(0)
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
bound_m: torch.Tensor | None = None
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
@@ -252,8 +252,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: Optional[torch.Tensor],
|
||||
orig_a_scale_block_shape: Optional[int],
|
||||
expert_x_scale: torch.Tensor | None,
|
||||
orig_a_scale_block_shape: int | None,
|
||||
) -> mk.PrepareResultType:
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
|
||||
@@ -271,7 +271,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
@@ -302,7 +302,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# This argument is optional
|
||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
bound_m: torch.Tensor | None = None
|
||||
|
||||
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
|
||||
# num_tokens = output.size(0) # M
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,10 +17,10 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
@@ -36,7 +35,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import IntEnum
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -53,13 +52,13 @@ def rocm_aiter_asm_moe_tkw1_impl(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
fc1_scale: Optional[torch.Tensor] = None,
|
||||
fc2_scale: Optional[torch.Tensor] = None,
|
||||
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||
fc1_scale: torch.Tensor | None = None,
|
||||
fc2_scale: torch.Tensor | None = None,
|
||||
fc1_smooth_scale: torch.Tensor | None = None,
|
||||
fc2_smooth_scale: torch.Tensor | None = None,
|
||||
a16: bool = False,
|
||||
per_tensor_quant_scale: Optional[torch.Tensor] = None,
|
||||
expert_mask: Optional[torch.Tensor] = None,
|
||||
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType
|
||||
@@ -90,13 +89,13 @@ def rocm_aiter_asm_moe_tkw1_fake(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
fc1_scale: Optional[torch.Tensor] = None,
|
||||
fc2_scale: Optional[torch.Tensor] = None,
|
||||
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||
fc1_scale: torch.Tensor | None = None,
|
||||
fc2_scale: torch.Tensor | None = None,
|
||||
fc1_smooth_scale: torch.Tensor | None = None,
|
||||
fc2_smooth_scale: torch.Tensor | None = None,
|
||||
a16: bool = False,
|
||||
per_tensor_quant_scale: Optional[torch.Tensor] = None,
|
||||
expert_mask: Optional[torch.Tensor] = None,
|
||||
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
@@ -206,14 +205,14 @@ def rocm_aiter_fused_moe_impl(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_mask: Optional[torch.Tensor] = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
quant_method: int = QuantMethod.NO.value,
|
||||
doweight_stage1: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
@@ -244,14 +243,14 @@ def rocm_aiter_fused_moe_fake(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_mask: Optional[torch.Tensor] = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
quant_method: int = QuantMethod.NO.value,
|
||||
doweight_stage1: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -300,7 +299,7 @@ def rocm_aiter_grouped_topk(
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
token = hidden_states.shape[0]
|
||||
device = hidden_states.device
|
||||
@@ -342,8 +341,8 @@ def rocm_aiter_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
) -> torch.Tensor:
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
@@ -10,7 +10,7 @@ like uniform random routing.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -24,7 +24,7 @@ class RoutingStrategy(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Route tokens to experts.
|
||||
@@ -89,7 +89,7 @@ class DistributionBasedRouting(RoutingStrategy):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Randomly select experts for each token using the specified distribution.
|
||||
@@ -269,7 +269,7 @@ class RoutingSimulator:
|
||||
router_logits: torch.Tensor,
|
||||
strategy_name: str,
|
||||
top_k: int,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Simulate token-to-expert routing using the specified strategy.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +17,7 @@ class SharedFusedMoE(FusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: Optional[torch.nn.Module],
|
||||
shared_experts: torch.nn.Module | None,
|
||||
use_overlapped: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -35,7 +34,7 @@ class SharedFusedMoE(FusedMoE):
|
||||
)
|
||||
|
||||
@property
|
||||
def shared_experts(self) -> Optional[torch.nn.Module]:
|
||||
def shared_experts(self) -> torch.nn.Module | None:
|
||||
return self._shared_experts if self.use_overlapped else None
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -29,7 +28,7 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: Optional[torch.Tensor],
|
||||
output: torch.Tensor | None,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
@@ -52,7 +51,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: Optional[torch.Tensor],
|
||||
output: torch.Tensor | None,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
@@ -84,7 +83,7 @@ class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: Optional[torch.Tensor],
|
||||
output: torch.Tensor | None,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
@@ -133,7 +132,7 @@ class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: Optional[torch.Tensor],
|
||||
output: torch.Tensor | None,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -89,7 +88,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
@@ -128,12 +127,12 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
use_deep_gemm = self.allow_deep_gemm and (
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -58,7 +57,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
@@ -100,12 +99,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
topk = topk_ids.size(-1)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from math import prod
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -60,7 +59,7 @@ def _count_expert_num_tokens(
|
||||
|
||||
|
||||
def count_expert_num_tokens(
|
||||
topk_ids: torch.Tensor, num_local_experts: int, expert_map: Optional[torch.Tensor]
|
||||
topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Count the number of tokens assigned to each expert.
|
||||
@@ -112,7 +111,7 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
||||
|
||||
def _nvfp4_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return flashinfer_fp4_quantize(
|
||||
@@ -122,9 +121,9 @@ def _nvfp4_quantize(
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Perform fp8 quantization on the inputs. If a block_shape
|
||||
@@ -148,9 +147,9 @@ def _fp8_quantize(
|
||||
|
||||
def _int8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Perform int8 quantization on the inputs. If a block_shape
|
||||
@@ -175,9 +174,9 @@ def _int8_quantize(
|
||||
|
||||
def _mxfp4_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
assert block_shape is None
|
||||
# TODO: native mxfp4 is currently not integrated in vllm,
|
||||
@@ -191,9 +190,9 @@ def _mxfp4_quantize(
|
||||
|
||||
def _mxfp8_e4m3_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert A_scale is None
|
||||
assert not per_act_token_quant
|
||||
@@ -203,9 +202,9 @@ def _mxfp8_e4m3_quantize(
|
||||
|
||||
def _mxfp6_e3m2_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
assert block_shape is None
|
||||
|
||||
@@ -220,9 +219,9 @@ def _mxfp6_e3m2_quantize(
|
||||
|
||||
def _mxfp6_e2m3_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
assert block_shape is None
|
||||
|
||||
@@ -237,12 +236,12 @@ def _mxfp6_e2m3_quantize(
|
||||
|
||||
def moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
quant_dtype: Union[None, torch.dtype, str],
|
||||
A_scale: torch.Tensor | None,
|
||||
quant_dtype: None | torch.dtype | str,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
is_fp4_scale_swizzled: bool = True,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
@@ -273,7 +272,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
def normalize_scales_shape(scales: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
scales = scales.view(1, 1)
|
||||
@@ -283,9 +282,9 @@ def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Ten
|
||||
|
||||
|
||||
def normalize_batched_scales_shape(
|
||||
scales: Optional[torch.Tensor],
|
||||
scales: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None and scales.ndim < 3:
|
||||
if scales.numel() == 1:
|
||||
scales = scales.view(1)
|
||||
@@ -300,9 +299,9 @@ def normalize_batched_scales_shape(
|
||||
|
||||
def _validate_scale_shape(
|
||||
a: torch.Tensor,
|
||||
a_scale: Optional[torch.Tensor],
|
||||
a_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
block_shape: list[int] | None,
|
||||
) -> None:
|
||||
if a_scale is None:
|
||||
return
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Custom normalization layers."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -159,9 +157,9 @@ class RMSNorm(CustomOp):
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
var_hidden_size: int | None = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -190,8 +188,8 @@ class RMSNorm(CustomOp):
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
@@ -231,8 +229,8 @@ class RMSNorm(CustomOp):
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
@@ -247,8 +245,8 @@ class RMSNorm(CustomOp):
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
@@ -263,8 +261,8 @@ class RMSNorm(CustomOp):
|
||||
def forward_xpu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
@@ -313,8 +311,8 @@ class GemmaRMSNorm(CustomOp):
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
if residual is not None:
|
||||
@@ -337,16 +335,16 @@ class GemmaRMSNorm(CustomOp):
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if torch.compiler.is_compiling():
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@@ -529,7 +528,7 @@ def lightning_attention(
|
||||
v: torch.Tensor,
|
||||
ed: torch.Tensor,
|
||||
block_size: int = 256,
|
||||
kv_history: Optional[torch.Tensor] = None,
|
||||
kv_history: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply lightning attention algorithm
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import itertools
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
@@ -187,7 +187,7 @@ class LinearMethodBase(QuantizeMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
@@ -252,7 +252,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
@@ -276,8 +276,8 @@ class LinearBase(CustomOp):
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
@@ -295,7 +295,7 @@ class LinearBase(CustomOp):
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
@@ -333,8 +333,8 @@ class ReplicatedLinear(LinearBase):
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
@@ -409,7 +409,7 @@ class ReplicatedLinear(LinearBase):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
|
||||
@@ -461,9 +461,9 @@ class ColumnParallelLinear(LinearBase):
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
output_sizes: list[int] | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
@@ -574,7 +574,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
@@ -633,8 +633,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
@@ -662,7 +662,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None,
|
||||
loaded_shard_id: int | None = None,
|
||||
):
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
@@ -838,7 +838,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None,
|
||||
loaded_shard_id: int | None = None,
|
||||
):
|
||||
if loaded_shard_id is None:
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
@@ -914,11 +914,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
total_num_kv_heads: int | None = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
@@ -1027,7 +1027,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None,
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
if loaded_shard_id is None: # special case for certain models
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
@@ -1071,7 +1071,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None,
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
@@ -1296,9 +1296,9 @@ class RowParallelLinear(LinearBase):
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
@@ -1405,7 +1405,7 @@ class RowParallelLinear(LinearBase):
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A layer that compute logits from hidden_stats."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
@@ -28,10 +26,10 @@ class LogitsProcessor(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
org_vocab_size: Optional[int] = None,
|
||||
org_vocab_size: int | None = None,
|
||||
scale: float = 1.0,
|
||||
logits_as_input: bool = False,
|
||||
soft_cap: Optional[float] = None,
|
||||
soft_cap: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@@ -53,8 +51,8 @@ class LogitsProcessor(CustomOp):
|
||||
self,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
embedding_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
@@ -88,8 +86,8 @@ class LogitsProcessor(CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
) -> Optional[torch.Tensor]:
|
||||
embedding_bias: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
# Get the logits for the next tokens.
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -87,8 +87,8 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
@@ -102,7 +102,7 @@ class MiniMaxText01LinearKernel:
|
||||
kv_caches: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
block_size: int,
|
||||
layer_idx: Optional[int] = None,
|
||||
layer_idx: int | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
slope_rate = slope_rate.to(torch.float32)
|
||||
@@ -154,9 +154,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
max_position: int,
|
||||
block_size: int,
|
||||
num_hidden_layer: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
layer_idx: int = 0,
|
||||
linear_layer_idx: int = 0,
|
||||
prefix: str = "linear_attn",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -68,8 +68,8 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu",
|
||||
is_lora_enabled: bool = False,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -410,7 +410,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
|
||||
return Mamba1AttentionBackend
|
||||
|
||||
def _time_proj_bias(self) -> Optional[torch.Tensor]:
|
||||
def _time_proj_bias(self) -> torch.Tensor | None:
|
||||
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
|
||||
return self.dt_proj.bias.float()
|
||||
return None
|
||||
@@ -423,8 +423,8 @@ class PrefillDecodeSplit(NamedTuple):
|
||||
gate_d: torch.Tensor
|
||||
state_indices_tensor_p: torch.Tensor
|
||||
state_indices_tensor_d: torch.Tensor
|
||||
query_start_loc_p: Optional[torch.Tensor]
|
||||
has_initial_states_p: Optional[torch.Tensor]
|
||||
query_start_loc_p: torch.Tensor | None
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
|
||||
|
||||
def split_batch_to_prefill_and_decode(
|
||||
@@ -432,7 +432,7 @@ def split_batch_to_prefill_and_decode(
|
||||
gate: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
has_initial_states: Optional[torch.Tensor],
|
||||
has_initial_states: torch.Tensor | None,
|
||||
num_prefill_tokens: int,
|
||||
num_decode_tokens: int,
|
||||
num_prefills: int,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -138,7 +138,7 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
@@ -244,9 +244,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -474,7 +474,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -482,7 +482,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
@@ -495,7 +495,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
@@ -904,7 +904,7 @@ def mamba_mixer2(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
@@ -915,7 +915,7 @@ def mamba_mixer2_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,7 +13,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
# TODO (tdoublep) requires testing
|
||||
@@ -26,7 +25,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def mamba1_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
@@ -37,7 +36,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def mamba2_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
@@ -48,7 +47,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def _mamba_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
@@ -63,7 +62,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def short_conv_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
@@ -72,7 +71,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def gated_delta_net_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -469,17 +468,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Union[torch.Tensor, None],
|
||||
bias: torch.Tensor | None,
|
||||
conv_states: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
cache_indices: torch.Tensor | None = None,
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
|
||||
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
num_computed_tokens: Optional[torch.Tensor] = None,
|
||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
num_computed_tokens: torch.Tensor | None = None,
|
||||
block_size_to_align=0,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
@@ -1071,15 +1070,15 @@ def causal_conv1d_update(
|
||||
x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Union[bool, str, None] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
activation: bool | str | None = None,
|
||||
conv_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
query_start_loc: torch.Tensor | None = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -38,8 +38,8 @@ class ShortConv(MambaBase, CustomOp):
|
||||
config,
|
||||
dim: int,
|
||||
layer_idx: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,14 +18,14 @@ class MLAModules:
|
||||
kv_b_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
fused_qkv_a_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_b_proj: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
indexer: Optional[torch.nn.Module]
|
||||
fused_qkv_a_proj: torch.nn.Module | None
|
||||
kv_a_proj_with_mqa: torch.nn.Module | None
|
||||
q_a_layernorm: torch.nn.Module | None
|
||||
q_b_proj: torch.nn.Module | None
|
||||
q_proj: torch.nn.Module | None
|
||||
indexer: torch.nn.Module | None
|
||||
is_sparse: bool
|
||||
topk_indices_buffer: Optional[torch.Tensor]
|
||||
topk_indices_buffer: torch.Tensor | None
|
||||
|
||||
|
||||
@CustomOp.register("multi_head_latent_attention")
|
||||
@@ -55,11 +54,11 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
mla_modules: MLAModules,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Set
|
||||
from collections.abc import Callable, Mapping, Set
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from itertools import groupby
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -24,8 +24,8 @@ from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PoolingFn = Callable[
|
||||
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
[torch.Tensor | list[torch.Tensor], PoolingMetadata],
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
]
|
||||
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
|
||||
@@ -90,7 +90,7 @@ class Pooler(nn.Module, ABC):
|
||||
@staticmethod
|
||||
def for_classify(
|
||||
pooler_config: PoolerConfig,
|
||||
classifier: Optional[ClassifierFn],
|
||||
classifier: ClassifierFn | None,
|
||||
):
|
||||
resolved_config = ResolvedPoolingConfig.from_config(
|
||||
task="classify",
|
||||
@@ -118,14 +118,14 @@ class Pooler(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[list[torch.Tensor], torch.Tensor],
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_prompt_lens(
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor:
|
||||
return pooling_metadata.prompt_lens
|
||||
@@ -174,7 +174,7 @@ def get_classification_activation_function(config: PretrainedConfig):
|
||||
|
||||
|
||||
def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
function_name: Optional[str] = None
|
||||
function_name: str | None = None
|
||||
if (
|
||||
hasattr(config, "sentence_transformers")
|
||||
and "activation_fn" in config.sentence_transformers
|
||||
@@ -223,14 +223,14 @@ class PoolingMethod(nn.Module, ABC):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
return self.forward_all(hidden_states, pooling_cursor)
|
||||
|
||||
@@ -243,7 +243,7 @@ class CLSPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with CLS pooling"
|
||||
)
|
||||
@@ -259,7 +259,7 @@ class LastPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||
|
||||
|
||||
@@ -271,7 +271,7 @@ class AllPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with ALL pooling"
|
||||
)
|
||||
@@ -290,7 +290,7 @@ class MeanPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with MEAN pooling"
|
||||
)
|
||||
@@ -405,7 +405,7 @@ class PoolerHead(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
):
|
||||
return self.activation(pooled_data)
|
||||
@@ -418,14 +418,14 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
# Load ST projector if available
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.projector: Optional[nn.Module] = (
|
||||
self.projector: nn.Module | None = (
|
||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
):
|
||||
if isinstance(pooled_data, list):
|
||||
@@ -480,7 +480,7 @@ class RewardPoolerHead(PoolerHead):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: Union[list[torch.Tensor], torch.Tensor],
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
):
|
||||
if isinstance(pooled_data, list):
|
||||
@@ -541,7 +541,7 @@ class SimplePooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
@@ -560,9 +560,9 @@ class StepPooler(Pooler):
|
||||
|
||||
def extract_states(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
||||
|
||||
@@ -593,7 +593,7 @@ class StepPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
@@ -621,8 +621,8 @@ class ClassifierPooler(Pooler):
|
||||
def __init__(
|
||||
self,
|
||||
pooling: PoolingFn,
|
||||
classifier: Optional[ClassifierFn],
|
||||
act_fn: Optional[PoolerActivation] = None,
|
||||
classifier: ClassifierFn | None,
|
||||
act_fn: PoolerActivation | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -631,7 +631,7 @@ class ClassifierPooler(Pooler):
|
||||
self.pooling = pooling
|
||||
self.classifier = classifier
|
||||
self.act_fn = act_fn or PoolerClassify()
|
||||
self.logit_bias: Optional[float] = (
|
||||
self.logit_bias: float | None = (
|
||||
vllm_config.model_config.pooler_config.logit_bias
|
||||
)
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
@@ -641,7 +641,7 @@ class ClassifierPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
@@ -695,7 +695,7 @@ class DispatchPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
poolers_by_task = self.poolers_by_task
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -46,8 +46,8 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
sym: bool = True,
|
||||
packing_format: str = "auto_round:auto_gptq",
|
||||
block_name_to_quantize: Optional[Union[str, list[str]]] = None,
|
||||
extra_config: Optional[dict[str, Any]] = None,
|
||||
block_name_to_quantize: str | list[str] | None = None,
|
||||
extra_config: dict[str, Any] | None = None,
|
||||
data_type: str = "int",
|
||||
backend: str = "auto",
|
||||
) -> None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -34,7 +34,7 @@ class AWQConfig(QuantizationConfig):
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
@@ -88,7 +88,7 @@ class AWQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
|
||||
) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
@@ -227,7 +227,7 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@@ -70,7 +71,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
modules_to_not_convert: list[str] | None,
|
||||
full_config: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -140,7 +141,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
|
||||
@@ -360,7 +361,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_awq_marlin_linear(
|
||||
input=x,
|
||||
@@ -555,7 +556,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
@@ -566,21 +567,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -105,7 +105,7 @@ class QuantizationConfig(ABC):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
"""
|
||||
Detects if this quantization method can support a given checkpoint
|
||||
format by overriding the user specified quantization method --
|
||||
@@ -135,7 +135,7 @@ class QuantizationConfig(ABC):
|
||||
@abstractmethod
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
) -> QuantizeMethodBase | None:
|
||||
"""Get the quantize method to use for the quantized layer.
|
||||
|
||||
Args:
|
||||
@@ -147,7 +147,7 @@ class QuantizationConfig(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper( # noqa: B027
|
||||
|
||||
@@ -45,10 +45,10 @@ class BitBLASConfig(QuantizationConfig):
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: Optional[int],
|
||||
desc_act: Optional[bool],
|
||||
is_sym: Optional[bool],
|
||||
quant_method: Optional[str],
|
||||
group_size: int | None,
|
||||
desc_act: bool | None,
|
||||
is_sym: bool | None,
|
||||
quant_method: str | None,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
try:
|
||||
@@ -160,7 +160,7 @@ class BitBLASConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_bitblas_format: bool
|
||||
is_bitblas_format = hf_quant_cfg.get(
|
||||
@@ -469,7 +469,7 @@ class BitBLASLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -41,7 +42,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
bnb_4bit_use_double_quant: bool = False,
|
||||
llm_int8_enable_fp32_cpu_offload: bool = False,
|
||||
llm_int8_has_fp16_weight: bool = False,
|
||||
llm_int8_skip_modules: Optional[list[str]] = None,
|
||||
llm_int8_skip_modules: list[str] | None = None,
|
||||
llm_int8_threshold: float = 6.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -138,7 +139,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
|
||||
) -> Union["LinearMethodBase", "BitsAndBytesMoEMethod"] | None:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
@@ -268,7 +269,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.load_in_8bit:
|
||||
return self._apply_8bit_weight(layer, x, bias)
|
||||
@@ -279,7 +280,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import MatmulLtState, matmul
|
||||
@@ -359,7 +360,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
@@ -489,7 +490,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
@@ -500,21 +501,21 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert self.fused_experts is None
|
||||
|
||||
@@ -71,7 +71,7 @@ logger = init_logger(__name__)
|
||||
__all__ = ["CompressedTensorsLinearMethod"]
|
||||
|
||||
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, QuantizationArgs] | None]
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
@@ -82,9 +82,9 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
quant_format: str,
|
||||
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: list[str],
|
||||
kv_cache_scheme: Optional[dict[str, Any]] = None,
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
transform_config: Optional[dict[str, Any]] = None,
|
||||
kv_cache_scheme: dict[str, Any] | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
transform_config: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
@@ -524,7 +524,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self,
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
format: Optional[str] = None,
|
||||
format: str | None = None,
|
||||
) -> "CompressedTensorsScheme":
|
||||
# use the per-layer format if defined, otherwise, use global format
|
||||
format = format if format is not None else self.quant_format
|
||||
@@ -631,7 +631,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
raise NotImplementedError("No compressed-tensors compatible scheme was found.")
|
||||
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: Optional[str] = None
|
||||
self, layer: torch.nn.Module, layer_name: str | None = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
@@ -674,7 +674,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
sparsity_targets = self.sparsity_scheme_map.keys() - set(
|
||||
self.sparsity_ignore_list
|
||||
)
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None
|
||||
sparsity_scheme: SparsityCompressionConfig | None = None
|
||||
with suppress(ValueError):
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
@@ -723,7 +723,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
@@ -751,9 +751,9 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
@staticmethod
|
||||
def supports_cutlass_24(
|
||||
weight_quant: Optional[QuantizationArgs],
|
||||
input_quant: Optional[QuantizationArgs],
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None,
|
||||
weight_quant: QuantizationArgs | None,
|
||||
input_quant: QuantizationArgs | None,
|
||||
sparsity_scheme: SparsityCompressionConfig | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the layer is supported by the Cutlass 2:4 Kernel
|
||||
@@ -853,7 +853,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
@@ -878,7 +878,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None):
|
||||
"""
|
||||
Validator for the kv cache scheme. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
@@ -372,7 +372,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
(layer.w2_input_global_scale), requires_grad=False
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
elif not self.allow_flashinfer:
|
||||
@@ -399,7 +399,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
|
||||
@@ -420,21 +420,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
|
||||
@@ -913,7 +913,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight_scale
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
@@ -997,7 +997,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
|
||||
@@ -1022,21 +1022,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
|
||||
@@ -1280,7 +1280,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return int8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
@@ -1297,21 +1297,21 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -1604,7 +1604,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
@@ -1615,21 +1615,21 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -1856,7 +1856,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
assert self.num_bits == 4 or self.num_bits == 8
|
||||
config_builder = (
|
||||
int4_w4a16_moe_quant_config
|
||||
@@ -1880,21 +1880,21 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
@@ -2092,7 +2092,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def _pack_matrix(
|
||||
int4_as_int8_2d: torch.Tensor,
|
||||
scales_2d: torch.Tensor,
|
||||
bias_1d: Optional[torch.Tensor],
|
||||
bias_1d: torch.Tensor | None,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
) -> torch.Tensor:
|
||||
@@ -2192,7 +2192,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# CPU dynamic 4-bit MoE path does not use modular kernels or
|
||||
# fused_experts; quant config is not needed.
|
||||
return None
|
||||
@@ -2205,20 +2205,20 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
|
||||
assert activation in ("silu", "swigluoai", "swiglu"), (
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
@@ -42,9 +43,9 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self,
|
||||
quantized: bool = False,
|
||||
weight_quant: Optional[QuantizationArgs] = None,
|
||||
input_quant: Optional[QuantizationArgs] = None,
|
||||
model_compression_config: Optional[dict[str, Any]] = None,
|
||||
weight_quant: QuantizationArgs | None = None,
|
||||
input_quant: QuantizationArgs | None = None,
|
||||
model_compression_config: dict[str, Any] | None = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
@@ -247,7 +248,7 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the output tensor for the layer with 2:4
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -33,7 +32,7 @@ class CompressedTensorsScheme(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@@ -30,7 +30,7 @@ W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None):
|
||||
def __init__(self, strategy: str, num_bits: int, group_size: int | None = None):
|
||||
self.strategy = strategy
|
||||
self.group_size = group_size
|
||||
self.tile_size = 16
|
||||
@@ -143,7 +143,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
layer.workspace = workspace
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.weight_packed
|
||||
meta = layer.meta
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -110,7 +110,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -156,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||
out = run_nvfp4_emulations(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
@@ -41,9 +41,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None,
|
||||
group_size: int | None = None,
|
||||
symmetric: bool | None = True,
|
||||
actorder: ActivationOrdering | None = None,
|
||||
):
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
@@ -178,6 +178,6 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
@@ -36,7 +36,7 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme):
|
||||
self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
group_size: int | None = None,
|
||||
is_static_input_scheme: bool = False,
|
||||
input_symmetric: bool = True,
|
||||
):
|
||||
@@ -148,6 +148,6 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme):
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
@@ -125,7 +125,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
@@ -179,7 +179,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.weight_block_size is not None:
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
@@ -120,6 +120,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
@@ -42,9 +42,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
symmetric: Optional[bool] = True,
|
||||
actorder: Optional[ActivationOrdering] = None,
|
||||
group_size: int | None = None,
|
||||
symmetric: bool | None = True,
|
||||
actorder: ActivationOrdering | None = None,
|
||||
):
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
@@ -214,6 +214,6 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from itertools import accumulate
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (
|
||||
@@ -38,7 +37,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
|
||||
def from_schemes(
|
||||
cls,
|
||||
quant_method: LinearMethodBase,
|
||||
quant_scheme: Optional[CompressedTensorsScheme],
|
||||
quant_scheme: CompressedTensorsScheme | None,
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
output_tfms: dict[int, TransformTuple],
|
||||
) -> "CompressedTensorsLinearTransformMethod":
|
||||
@@ -66,8 +65,8 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
|
||||
self.input_tfms = input_tfms
|
||||
self.output_tfms = output_tfms
|
||||
|
||||
self.input_transform: Optional[HadamardTransform] = None
|
||||
self.output_transform: Optional[HadamardTransform] = None
|
||||
self.input_transform: HadamardTransform | None = None
|
||||
self.output_transform: HadamardTransform | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -151,7 +150,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.input_transform is not None:
|
||||
x = self.input_transform(x)
|
||||
@@ -194,7 +193,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase):
|
||||
def get_linear_transform_schemes(
|
||||
layer: torch.nn.Module,
|
||||
layer_name: str,
|
||||
transform_config: Optional[TransformConfig],
|
||||
transform_config: TransformConfig | None,
|
||||
packed_modules_mapping: dict[str, list[str]],
|
||||
) -> tuple[
|
||||
dict[int, TransformTuple], dict[int, TransformTuple]
|
||||
@@ -226,7 +225,7 @@ def get_linear_transform_schemes(
|
||||
|
||||
|
||||
def get_schemes_args(
|
||||
transform_config: Optional[TransformConfig],
|
||||
transform_config: TransformConfig | None,
|
||||
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
|
||||
if transform_config is None:
|
||||
return
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Hashable
|
||||
from typing import Callable
|
||||
from collections.abc import Callable, Hashable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,7 +16,7 @@ __all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
|
||||
|
||||
|
||||
def is_qutlass_fp4_scheme(
|
||||
quant_scheme: Optional[CompressedTensorsScheme],
|
||||
quant_scheme: CompressedTensorsScheme | None,
|
||||
input_tfms: dict[int, TransformTuple],
|
||||
) -> bool:
|
||||
return (
|
||||
@@ -60,6 +59,6 @@ class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -145,7 +144,7 @@ def triton_scaled_mm(
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
block_size_k: int = 32,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
|
||||
import regex as re
|
||||
from compressed_tensors import CompressionFormat
|
||||
@@ -21,7 +20,7 @@ def is_activation_quantization_format(format: str) -> bool:
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
layer_name: str | None,
|
||||
ignore: Iterable[str] = tuple(),
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
) -> bool:
|
||||
@@ -84,7 +83,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
|
||||
|
||||
|
||||
def find_matched_target(
|
||||
layer_name: Optional[str],
|
||||
layer_name: str | None,
|
||||
module: Module,
|
||||
targets: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
@@ -134,7 +133,7 @@ def find_matched_target(
|
||||
|
||||
def _find_first_match(
|
||||
value: str, targets: Iterable[str], check_contains: bool = False
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
Returns first element of target that matches value either
|
||||
exactly or as a regex after 're:'. If check_contains is set to True,
|
||||
@@ -176,7 +175,7 @@ def _match_fused_layer(
|
||||
layer_name: str,
|
||||
target_layers: Iterable[str],
|
||||
fused_mapping: Mapping[str, list[str]],
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
Match a fused layer name to its corresponding individual layer in
|
||||
target_layers. Returns first value in fused_mapping which matches targets
|
||||
@@ -205,7 +204,7 @@ def _match_fused_layer(
|
||||
]
|
||||
|
||||
# for each unfused component, find a match in targets
|
||||
unfused_matches: list[Optional[str]] = []
|
||||
unfused_matches: list[str | None] = []
|
||||
for unfused in unfused_paths:
|
||||
for target in target_layers:
|
||||
if _is_equal_or_regex_match(unfused, target):
|
||||
|
||||
@@ -140,7 +140,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
y = weight.ds_dequantize()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -129,7 +130,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return int8_w8a16_moe_quant_config(
|
||||
w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None
|
||||
)
|
||||
@@ -142,21 +143,21 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -171,7 +171,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -173,8 +174,8 @@ class Fp8Config(QuantizationConfig):
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[list[str]] = None,
|
||||
weight_block_size: Optional[list[int]] = None,
|
||||
ignored_layers: list[str] | None = None,
|
||||
weight_block_size: list[int] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -298,7 +299,7 @@ class Fp8Config(QuantizationConfig):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
@@ -530,7 +531,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
@@ -584,12 +585,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant: bool = self.weight_block_size is not None
|
||||
|
||||
self.fused_experts: Optional[mk.FusedMoEModularKernel] = None # type: ignore
|
||||
self.fused_experts: mk.FusedMoEModularKernel | None = None # type: ignore
|
||||
|
||||
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
|
||||
|
||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
@@ -970,7 +971,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale_inv
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if (
|
||||
self.rocm_aiter_moe_enabled
|
||||
or self.use_marlin
|
||||
@@ -1043,7 +1044,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
|
||||
@@ -1069,21 +1070,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -36,7 +36,7 @@ class FPQuantConfig(QuantizationConfig):
|
||||
forward_dtype: str = "mxfp4",
|
||||
forward_method: str = "abs_max",
|
||||
pseudoquantization: bool = False,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hadamard_group_size = hadamard_group_size
|
||||
@@ -90,7 +90,7 @@ class FPQuantConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[LinearMethodBase]:
|
||||
) -> LinearMethodBase | None:
|
||||
if self.modules_to_not_convert is not None and any(
|
||||
prefix.endswith(module) for module in self.modules_to_not_convert
|
||||
):
|
||||
@@ -233,7 +233,7 @@ class FPQuantLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return quantized_forward(
|
||||
x,
|
||||
@@ -381,7 +381,7 @@ def quantized_forward(
|
||||
weight_scales: torch.Tensor,
|
||||
weight_global_scale: torch.Tensor,
|
||||
act_global_scale: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
bias: torch.Tensor | None,
|
||||
forward_hadamard_matrix: torch.Tensor,
|
||||
forward_method: str,
|
||||
forward_dtype: str,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@@ -35,7 +36,7 @@ logger = init_logger(__name__)
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self, unquantized_modules: Optional[list[str]] = None) -> None:
|
||||
def __init__(self, unquantized_modules: list[str] | None = None) -> None:
|
||||
super().__init__()
|
||||
self.unquantized_modules = unquantized_modules or []
|
||||
|
||||
@@ -307,7 +308,7 @@ def _apply_gguf_embedding(
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return torch.embedding(qweight, x)
|
||||
@@ -330,7 +331,7 @@ def _apply_gguf_embedding_fake(
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
|
||||
|
||||
@@ -452,7 +453,7 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
shard_id = layer.qweight.shard_id
|
||||
|
||||
@@ -558,7 +559,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
@@ -569,21 +570,21 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -48,9 +48,9 @@ class GPTQConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
dynamic: dict[str, dict[str, int | bool]],
|
||||
autoround_version: str = "",
|
||||
modules_in_block_to_quantize: Optional[list[str]] = None,
|
||||
modules_in_block_to_quantize: list[str] | None = None,
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
@@ -148,7 +148,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
|
||||
) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None:
|
||||
if isinstance(layer, FusedMoE):
|
||||
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
@@ -170,7 +170,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: Optional[str] = None):
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
@@ -345,7 +345,7 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
@@ -71,7 +71,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
quant_method: Optional[str],
|
||||
quant_method: str | None,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
try:
|
||||
@@ -180,7 +180,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
@@ -474,7 +474,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
|
||||
if bias is not None:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -103,9 +104,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
dynamic: dict[str, dict[str, int | bool]],
|
||||
full_config: dict[str, Any],
|
||||
modules_in_block_to_quantize: Optional[list[str]] = None,
|
||||
modules_in_block_to_quantize: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
@@ -211,7 +212,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
@@ -283,7 +284,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: Optional[str] = None):
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
@@ -459,7 +460,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -714,7 +715,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
@@ -725,21 +726,21 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
|
||||
@@ -114,7 +114,7 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24"
|
||||
|
||||
is_valid_user_quant = (
|
||||
@@ -287,7 +287,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B_24
|
||||
meta = layer.B_meta
|
||||
|
||||
@@ -45,7 +45,7 @@ class HQQMarlinConfig(QuantizationConfig):
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
skip_modules: Optional[list[str]] = None,
|
||||
skip_modules: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert group_size == 64, "The only supported HQQ group size is currently 64."
|
||||
@@ -327,7 +327,7 @@ class HQQMarlinMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
workspace = MarlinWorkspace(
|
||||
self.output_size_per_partition,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -30,9 +29,9 @@ class QuantFP8(CustomOp):
|
||||
self,
|
||||
static: bool,
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: Optional[int] = None,
|
||||
num_token_padding: int | None = None,
|
||||
column_major_scales: bool = False,
|
||||
use_ue8m0: Optional[bool] = None, # for Torch compile
|
||||
use_ue8m0: bool | None = None, # for Torch compile
|
||||
):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
@@ -64,8 +63,8 @@ class QuantFP8(CustomOp):
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
scale_ub: Optional[torch.Tensor] = None,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Group quantization is always dynamic"
|
||||
@@ -96,8 +95,8 @@ class QuantFP8(CustomOp):
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
scale_ub: Optional[torch.Tensor] = None,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
):
|
||||
if self.is_group_quant:
|
||||
assert scale is None, "Group quantization is always dynamic"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -50,9 +51,9 @@ class IPEXConfig(QuantizationConfig):
|
||||
method: str,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
desc_act: Optional[bool] = None,
|
||||
lm_head_quantized: Optional[bool] = None,
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
desc_act: bool | None = None,
|
||||
lm_head_quantized: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.method = method
|
||||
@@ -122,7 +123,7 @@ class IPEXConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional[QuantizationMethods]:
|
||||
) -> QuantizationMethods | None:
|
||||
if not current_platform.is_cpu() and not current_platform.is_xpu():
|
||||
return None
|
||||
|
||||
@@ -206,7 +207,7 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = layer.ipex_qlinear(reshaped_x)
|
||||
@@ -275,7 +276,7 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = layer.ipex_qlinear(reshaped_x)
|
||||
@@ -299,7 +300,7 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
weight = layer.weight.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
@@ -410,7 +411,7 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional[FusedMoEQuantConfig]:
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
@@ -421,20 +422,20 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,7 @@ class MPLinearLayerConfig:
|
||||
group_size: int
|
||||
zero_points: bool
|
||||
has_g_idx: bool
|
||||
out_type: Optional[torch.dtype] = None
|
||||
out_type: torch.dtype | None = None
|
||||
|
||||
|
||||
class MPLinearKernel(ABC):
|
||||
@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
@@ -39,8 +39,8 @@ class MPLinearKernel(ABC):
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
w_zp_param_name: str | None = None,
|
||||
w_gidx_param_name: str | None = None,
|
||||
) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
@@ -62,12 +62,12 @@ class MPLinearKernel(ABC):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _transform_param(
|
||||
self, layer: torch.nn.Module, name: Optional[str], fn: Callable
|
||||
self, layer: torch.nn.Module, name: str | None, fn: Callable
|
||||
) -> None:
|
||||
if name is not None and getattr(layer, name, None) is not None:
|
||||
old_param = getattr(layer, name)
|
||||
@@ -83,8 +83,8 @@ class MPLinearKernel(ABC):
|
||||
) -> tuple[
|
||||
torch.Tensor, # w_q
|
||||
torch.Tensor, # w_s
|
||||
Optional[torch.Tensor], # w_zp,
|
||||
Optional[torch.Tensor], # w_gidx
|
||||
torch.Tensor | None, # w_zp,
|
||||
torch.Tensor | None, # w_gidx
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
AllSparkLinearKernel,
|
||||
@@ -48,7 +46,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig, compute_capability: Optional[int] = None
|
||||
config: MPLinearLayerConfig, compute_capability: int | None = None
|
||||
) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,7 +21,7 @@ class AllSparkLinearKernel(MPLinearKernel):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering currently not supported by AllSpark"
|
||||
|
||||
@@ -87,7 +86,7 @@ class AllSparkLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
gemm_args = self.gemm_args
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -44,9 +43,9 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
bitblas_quant_config: Optional[QuantizationConfig] = None,
|
||||
w_zp_param_name: str | None = None,
|
||||
w_gidx_param_name: str | None = None,
|
||||
bitblas_quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
self.quant_config = bitblas_quant_config
|
||||
super().__init__(
|
||||
@@ -57,7 +56,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
b_q_weight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: Optional[torch.Tensor] = None,
|
||||
qzeros: torch.Tensor | None = None,
|
||||
):
|
||||
from bitblas.quantization.utils import general_compress
|
||||
|
||||
@@ -82,7 +81,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
# qzeros should be de-quantized to int zeros.
|
||||
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
|
||||
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
|
||||
zeros: Optional[torch.Tensor] = None
|
||||
zeros: torch.Tensor | None = None
|
||||
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
|
||||
if zeros_mode == "original":
|
||||
zeros = intzeros.to(torch.float16).contiguous()
|
||||
@@ -113,7 +112,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
is_bitblas_installed = True
|
||||
|
||||
try:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from importlib.util import find_spec
|
||||
from typing import Final, Optional
|
||||
from typing import Final
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +26,7 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
|
||||
error_msg = (
|
||||
f"Weight type ({c.weight_type}) not supported by "
|
||||
@@ -76,7 +76,7 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from conch.ops.quantization.gemm import mixed_precision_gemm
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +25,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CUTLASS only supported on CUDA"
|
||||
|
||||
@@ -95,7 +94,7 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,7 +19,7 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Only CPU is supported"
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
@@ -95,7 +94,7 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,7 +24,7 @@ class ExllamaLinearKernel(MPLinearKernel):
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return (
|
||||
False,
|
||||
@@ -137,7 +136,7 @@ class ExllamaLinearKernel(MPLinearKernel):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user