Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,8 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -539,10 +540,10 @@ def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: Optional[torch.Tensor],
|
||||
topk_weights: Optional[torch.Tensor],
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
@@ -555,8 +556,8 @@ def invoke_fused_moe_kernel(
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
B_bias: Optional[torch.Tensor] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
@@ -808,7 +809,7 @@ def zero_experts_compute_triton(
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
|
||||
def get_config_file_name(
|
||||
E: int, N: int, dtype: Optional[str], block_shape: Optional[list[int]] = None
|
||||
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
|
||||
) -> str:
|
||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||
@@ -823,10 +824,10 @@ def get_config_file_name(
|
||||
def get_moe_configs(
|
||||
E: int,
|
||||
N: int,
|
||||
dtype: Optional[str],
|
||||
block_n: Optional[int] = None,
|
||||
block_k: Optional[int] = None,
|
||||
) -> Optional[dict[int, Any]]:
|
||||
dtype: str | None,
|
||||
block_n: int | None = None,
|
||||
block_k: int | None = None,
|
||||
) -> dict[int, Any] | None:
|
||||
"""
|
||||
Return optimized configurations for the fused MoE kernel.
|
||||
|
||||
@@ -965,8 +966,8 @@ def get_default_config(
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
dtype: Optional[str],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
dtype: str | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
@@ -1016,9 +1017,9 @@ def try_get_optimal_moe_config(
|
||||
w1_shape: tuple[int, ...],
|
||||
w2_shape: tuple[int, ...],
|
||||
top_k: int,
|
||||
dtype: Optional[str],
|
||||
dtype: str | None,
|
||||
M: int,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
from vllm.model_executor.layers.fused_moe import get_config
|
||||
|
||||
@@ -1076,7 +1077,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
|
||||
|
||||
@@ -1135,7 +1136,7 @@ def grouped_topk(
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
|
||||
@@ -1211,7 +1212,7 @@ def eplb_map_to_physical_and_record(
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map the logical expert ids to physical expert ids
|
||||
@@ -1326,19 +1327,19 @@ def inplace_fused_experts(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1381,19 +1382,19 @@ def inplace_fused_experts_fake(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -1423,19 +1424,19 @@ def outplace_fused_experts(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1477,19 +1478,19 @@ def outplace_fused_experts_fake(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -1534,8 +1535,8 @@ def fused_experts(
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
allow_cutlass_block_scaled_grouped_gemm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -1625,8 +1626,8 @@ GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
def _get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
ocp_mx_scheme: Optional[str],
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
ocp_mx_scheme: str | None,
|
||||
) -> None | torch.dtype | str:
|
||||
"""
|
||||
Get the quantization type based on the quantization strategy flags.
|
||||
We don't have a quant_config at this point so we need to work backwards.
|
||||
@@ -1660,19 +1661,19 @@ def fused_experts_impl(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: Optional[str] = None,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
@@ -1964,7 +1965,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M, topk, max(N // 2, K))
|
||||
workspace2 = (M, topk, max(N, K))
|
||||
@@ -1981,12 +1982,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
@@ -2074,7 +2075,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
a2q_scale: torch.Tensor | None = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2,
|
||||
|
||||
Reference in New Issue
Block a user