Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
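What the diff below applies throughout the file is the PEP 604 / PEP 585 spelling of optional and union types. A minimal before/after sketch (the function and parameter names here are illustrative, not taken from vLLM):

    # Old spellings, via the typing module
    from typing import Callable, Optional, Union

    def scale_output(fn: Callable, scale: Optional[float], step: Union[int, float]) -> Optional[float]:
        return fn(step) * scale if scale is not None else None

    # New spellings: built-in | unions (PEP 604), Callable from collections.abc (PEP 585)
    from collections.abc import Callable

    def scale_output(fn: Callable, scale: float | None, step: int | float) -> float | None:
        return fn(step) * scale if scale is not None else None

Both versions are semantically identical; the new form drops the typing-module aliases that PEP 585 deprecates.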
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """CUTLASS based Fused MoE kernels."""
 
-from typing import Callable, Optional
+from collections.abc import Callable
 
 import torch
 
@@ -35,23 +35,23 @@ def run_cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     activation_callable: Callable,
     global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
-    a1q_scale: Optional[torch.Tensor],
-    a2_scale: Optional[torch.Tensor],
+    expert_map: torch.Tensor | None,
+    w1_scale: torch.Tensor | None,
+    w2_scale: torch.Tensor | None,
+    a1q_scale: torch.Tensor | None,
+    a2_scale: torch.Tensor | None,
     ab_strides1: torch.Tensor,
     ab_strides2: torch.Tensor,
     c_strides1: torch.Tensor,
     c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
-    expert_num_tokens: Optional[torch.Tensor],
+    expert_num_tokens: torch.Tensor | None,
     out_dtype: torch.dtype,
     per_act_token: bool,
     per_out_ch: bool,
     use_batched_format: bool,
-    topk_weights: Optional[torch.Tensor],
+    topk_weights: torch.Tensor | None,
 ):
     a1q = hidden_states
 
@@ -249,7 +249,7 @@ def run_cutlass_moe_fp8(
 class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(
         self,
-        out_dtype: Optional[torch.dtype],
+        out_dtype: torch.dtype | None,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
@@ -278,12 +278,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool,
     ):
         assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
@@ -331,7 +331,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
 class CutlassExpertsFp8(CutlassExpertsFp8Base):
     def __init__(
         self,
-        out_dtype: Optional[torch.dtype],
+        out_dtype: torch.dtype | None,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
@@ -377,7 +377,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         workspace1 = (M * topk, max(N, K))
         workspace2 = (M * topk, max(N // 2, K))
@@ -390,7 +390,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         self,
         max_experts_per_worker: int,
         num_dispatchers: int,
-        out_dtype: Optional[torch.dtype],
+        out_dtype: torch.dtype | None,
         ab_strides1: torch.Tensor,
         ab_strides2: torch.Tensor,
         c_strides1: torch.Tensor,
@@ -435,7 +435,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         num_dp = self.num_dispatchers
         assert num_dp is not None
@@ -457,7 +457,7 @@ def cutlass_moe_fp8(
     c_strides2: torch.Tensor,
     quant_config: FusedMoEQuantConfig,
     activation: str = "silu",
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
 ) -> torch.Tensor:
@@ -768,7 +768,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()
@@ -793,12 +793,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],  # unused
-        a2_scale: Optional[torch.Tensor],  # unused
-        workspace13: Optional[torch.Tensor],
-        workspace2: Optional[torch.Tensor],
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,  # unused
+        a2_scale: torch.Tensor | None,  # unused
+        workspace13: torch.Tensor | None,
+        workspace2: torch.Tensor | None,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool,
     ):
         e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
@@ -839,7 +839,7 @@ def cutlass_moe_fp4(
     n: int,
     k: int,
     e: int,
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     assert expert_map is None, (
@@ -896,7 +896,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
     inplace: bool,
     activation: str,
     apply_router_weight_on_input: bool,
-    expert_map: Optional[torch.Tensor],
+    expert_map: torch.Tensor | None,
 ) -> bool:
     def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
         return N % 128 == 0 and K % 128 == 0
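One caveat worth noting alongside a change like this (a general property of the syntax, not a claim about vLLM's minimum Python version): X | None inside an annotation is only evaluated at runtime on Python 3.10+. On older interpreters the same spelling still works when annotation evaluation is deferred, as in this hypothetical module:

    # PEP 563 deferred annotations: with this future import, the annotation
    # below is stored as a string and never evaluated, so pre-3.10
    # interpreters accept the | spelling too.
    from __future__ import annotations

    import torch

    def lookup(expert_map: torch.Tensor | None = None) -> torch.Tensor | None:
        # The runtime check stays an explicit None comparison, which is
        # version-independent; isinstance(x, A | B) itself needs 3.10+.
        return expert_map if expert_map is not None else None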