Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
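For context, the commit mechanically rewrites Optional[x] as x | None and Union[x, y] as x | y (PEP 604 union syntax). A minimal sketch of the rewrite, not taken from the changed file, using plain int/str in place of the torch types:

# Sketch only -- illustrates the rewrite applied throughout the diff below.
from typing import Optional, Union


def old_style(x: Optional[int], y: Union[int, str]) -> Optional[str]:  # before
    return str(x) if x is not None and isinstance(y, int) else None


def new_style(x: int | None, y: int | str) -> str | None:  # after (PEP 604)
    return str(x) if x is not None and isinstance(y, int) else None


# On Python 3.10+ the two spellings compare equal at runtime.
assert Optional[int] == (int | None)
assert Union[int, str] == (int | str)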
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
 from math import prod
-from typing import Callable, Optional, Union, final
+from typing import final
 
 import torch
 
@@ -81,7 +82,7 @@ class ExpertTokensMetadata:
     """
 
     expert_num_tokens: torch.Tensor
-    expert_num_tokens_cpu: Optional[torch.Tensor]
+    expert_num_tokens_cpu: torch.Tensor | None
 
     @staticmethod
     def make_from_list(
@@ -104,7 +105,7 @@ class TopKWeightAndReduce(ABC):
     @abstractmethod
     def apply(
         self,
-        output: Optional[torch.Tensor],
+        output: torch.Tensor | None,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
@@ -132,10 +133,10 @@ class TopKWeightAndReduce(ABC):
 #
 PrepareResultType = tuple[
     torch.Tensor,
-    Optional[torch.Tensor],
-    Optional[ExpertTokensMetadata],
-    Optional[torch.Tensor],
-    Optional[torch.Tensor],
+    torch.Tensor | None,
+    ExpertTokensMetadata | None,
+    torch.Tensor | None,
+    torch.Tensor | None,
 ]
 
 ReceiverType = Callable[[], PrepareResultType]
@@ -155,7 +156,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         num_experts: int,
-        expert_map: Optional[torch.Tensor],
+        expert_map: torch.Tensor | None,
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
     ) -> PrepareResultType:
@@ -195,10 +196,10 @@ class FusedMoEPrepareAndFinalize(ABC):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         num_experts: int,
-        expert_map: Optional[torch.Tensor],
+        expert_map: torch.Tensor | None,
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
+    ) -> tuple[Callable, ReceiverType] | ReceiverType:
         """
         Perform any quantization (and/or) dispatching needed for this kernel
         but do not wait for results from other workers.
@@ -270,7 +271,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: TopKWeightAndReduce,
-    ) -> Union[tuple[Callable, Callable], Callable]:
+    ) -> tuple[Callable, Callable] | Callable:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output but do not wait for results from other workers.
@@ -314,7 +315,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def topk_indices_dtype(self) -> Optional[torch.dtype]:
+    def topk_indices_dtype(self) -> torch.dtype | None:
         """
         The PrepareFinalize All2All implementations generally constrain the
         dtype of the topk_ids they support. This function returns the
@@ -324,7 +325,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def max_num_tokens_per_rank(self) -> Optional[int]:
+    def max_num_tokens_per_rank(self) -> int | None:
         """
         Some PrepareFinalize All2All implementations are batched. Meaning,
         they can process only as set of tokens at a time. This
@@ -423,11 +424,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     #
 
     @property
-    def quant_dtype(self) -> Optional[torch.dtype]:
+    def quant_dtype(self) -> torch.dtype | None:
         return self.quant_config.quant_dtype
 
     @property
-    def block_shape(self) -> Optional[list[int]]:
+    def block_shape(self) -> list[int] | None:
         return self.quant_config.block_shape
 
     @property
@@ -439,51 +440,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         return self.quant_config.per_out_ch_quant
 
     @property
-    def a1_scale(self) -> Optional[torch.Tensor]:
+    def a1_scale(self) -> torch.Tensor | None:
         return self.quant_config.a1_scale
 
     @property
-    def a2_scale(self) -> Optional[torch.Tensor]:
+    def a2_scale(self) -> torch.Tensor | None:
         return self.quant_config.a2_scale
 
    @property
-    def a1_gscale(self) -> Optional[torch.Tensor]:
+    def a1_gscale(self) -> torch.Tensor | None:
         return self.quant_config.a1_gscale
 
     @property
-    def a2_gscale(self) -> Optional[torch.Tensor]:
+    def a2_gscale(self) -> torch.Tensor | None:
         return self.quant_config.a2_gscale
 
     @property
-    def w1_scale(self) -> Optional[torch.Tensor]:
+    def w1_scale(self) -> torch.Tensor | None:
         return self.quant_config.w1_scale
 
     @property
-    def w2_scale(self) -> Optional[torch.Tensor]:
+    def w2_scale(self) -> torch.Tensor | None:
         return self.quant_config.w2_scale
 
     @property
-    def w1_zp(self) -> Optional[torch.Tensor]:
+    def w1_zp(self) -> torch.Tensor | None:
         return self.quant_config.w1_zp
 
     @property
-    def w2_zp(self) -> Optional[torch.Tensor]:
+    def w2_zp(self) -> torch.Tensor | None:
         return self.quant_config.w2_zp
 
     @property
-    def w1_bias(self) -> Optional[torch.Tensor]:
+    def w1_bias(self) -> torch.Tensor | None:
         return self.quant_config.w1_bias
 
     @property
-    def w2_bias(self) -> Optional[torch.Tensor]:
+    def w2_bias(self) -> torch.Tensor | None:
         return self.quant_config.w2_bias
 
     @property
-    def g1_alphas(self) -> Optional[torch.Tensor]:
+    def g1_alphas(self) -> torch.Tensor | None:
         return self.quant_config.g1_alphas
 
     @property
-    def g2_alphas(self) -> Optional[torch.Tensor]:
+    def g2_alphas(self) -> torch.Tensor | None:
         return self.quant_config.g2_alphas
 
     # TODO (bnell): make this return a CHUNK_SIZE or None instead?
@@ -517,7 +518,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         topk: int,
         global_num_experts: int,
         local_num_experts: int,
-        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        expert_tokens_meta: ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -578,12 +579,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        expert_tokens_meta: ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool,
     ) -> None:
         """
@@ -625,8 +626,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
 
 
 def _slice_scales(
-    scales: Optional[torch.Tensor], start: int, end: int
-) -> Optional[torch.Tensor]:
+    scales: torch.Tensor | None, start: int, end: int
+) -> torch.Tensor | None:
     if scales is not None:
         if scales.numel() == 1:
             return scales
@@ -688,7 +689,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
         fused_experts: FusedMoEPermuteExpertsUnpermute,
-        shared_experts: Optional[torch.nn.Module] = None,
+        shared_experts: torch.nn.Module | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
@@ -741,7 +742,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         top_k: int,
         global_num_experts: int,
         local_num_experts: int,
-        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        expert_tokens_meta: ExpertTokensMetadata | None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Allocate temporary and output buffers for the fused experts op.
@@ -825,11 +826,11 @@ class FusedMoEModularKernel(torch.nn.Module):
     @staticmethod
     def _slice_expert_tokens_metadata(
         num_chunks: int,
-        full_expert_tokens_meta: Optional[ExpertTokensMetadata],
+        full_expert_tokens_meta: ExpertTokensMetadata | None,
         chunk_topk_ids: torch.Tensor,
         local_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-    ) -> Optional[ExpertTokensMetadata]:
+        expert_map: torch.Tensor | None,
+    ) -> ExpertTokensMetadata | None:
         if num_chunks == 1 or full_expert_tokens_meta is None:
             return full_expert_tokens_meta
 
@@ -861,12 +862,12 @@ class FusedMoEModularKernel(torch.nn.Module):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
+        expert_map: torch.Tensor | None,
         apply_router_weight_on_input: bool,
     ) -> tuple[
         torch.Tensor,
-        Optional[torch.Tensor],
-        Optional[ExpertTokensMetadata],
+        torch.Tensor | None,
+        ExpertTokensMetadata | None,
         torch.Tensor,
         torch.Tensor,
     ]:
@@ -945,7 +946,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         self,
         in_dtype: torch.dtype,
         a1q: torch.Tensor,
-        a1q_scale: Optional[torch.Tensor],
+        a1q_scale: torch.Tensor | None,
         w1: torch.Tensor,
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
@@ -953,9 +954,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         activation: str,
         global_num_experts: int,
         local_num_experts: int,
-        expert_map: Optional[torch.Tensor],
+        expert_map: torch.Tensor | None,
         apply_router_weight_on_input: bool,
-        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        expert_tokens_meta: ExpertTokensMetadata | None,
     ) -> torch.Tensor:
         _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
             a1q, w1, w2, topk_ids
@@ -1042,12 +1043,12 @@ class FusedMoEModularKernel(torch.nn.Module):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         The _finalize method is a wrapper around self.prepare_finalize.finalize
         that handles DBO, async and shared expert overlap.
         """
-        shared_output: Optional[torch.Tensor] = None
+        shared_output: torch.Tensor | None = None
 
         if not self.prepare_finalize.supports_async():
             assert not dbo_enabled()
@@ -1112,9 +1113,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         inplace: bool = False,
         activation: str = "silu",
         global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
+        expert_map: torch.Tensor | None = None,
         apply_router_weight_on_input: bool = False,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         This function computes a Mixture of Experts (MoE) layer using two sets
         of weights, w1 and w2, and top-k gating mechanism.
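One general-Python caveat, offered as background rather than a statement about this project's support matrix: function and attribute annotations can be made lazy with from __future__ import annotations, but module-level expressions such as the PrepareResultType alias above are evaluated eagerly, so the X | None spelling there requires an interpreter that implements PEP 604 (Python 3.10+). A minimal sketch of the distinction, with a deliberately undefined name to show that annotations are never evaluated:

# Sketch only. PEP 563 lazy annotations vs. eagerly evaluated expressions.
from __future__ import annotations


def f(x: NotDefinedAnywhere | None) -> None:  # never evaluated; stored as a string
    pass


Alias = int | None  # evaluated immediately; needs Python 3.10+ (PEP 604)

print(f.__annotations__["x"])  # "NotDefinedAnywhere | None" (a plain string)
print(Alias)                   # int | None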