Update `Optional[x]` to `x | None` and `Union[x, y]` to `x | y` (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUTLASS based Fused MoE kernels."""
from typing import Callable, Optional
from collections.abc import Callable
import torch
@@ -35,23 +35,23 @@ def run_cutlass_moe_fp8(
topk_ids: torch.Tensor,
activation_callable: Callable,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_num_tokens: torch.Tensor | None,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
topk_weights: Optional[torch.Tensor],
topk_weights: torch.Tensor | None,
):
a1q = hidden_states
@@ -249,7 +249,7 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: Optional[torch.dtype],
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
@@ -278,12 +278,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
@@ -331,7 +331,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
@@ -377,7 +377,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
@@ -390,7 +390,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
out_dtype: torch.dtype | None,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
@@ -435,7 +435,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
@@ -457,7 +457,7 @@ def cutlass_moe_fp8(
c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
activation: str = "silu",
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor:
@@ -768,7 +768,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
@@ -793,12 +793,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], # unused
a2_scale: Optional[torch.Tensor], # unused
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, # unused
a2_scale: torch.Tensor | None, # unused
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
@@ -839,7 +839,7 @@ def cutlass_moe_fp4(
n: int,
k: int,
e: int,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
assert expert_map is None, (
@@ -896,7 +896,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
inplace: bool,
activation: str,
apply_router_weight_on_input: bool,
expert_map: Optional[torch.Tensor],
expert_map: torch.Tensor | None,
) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
    """Report whether both dimensions are exact multiples of 128.

    Args:
        N: First dimension to check.
        K: Second dimension to check.

    Returns:
        ``True`` only if ``N`` and ``K`` are each divisible by 128 with no
        remainder; ``False`` otherwise.
    """
    # Both dimensions must pass the same divisibility check.
    return all(dim % 128 == 0 for dim in (N, K))