Update Optional[x] to x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,6 @@ Test deepep dispatch-combine logic
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
@@ -90,7 +89,7 @@ class TestConfig:
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
rank_tokens: torch.Tensor # all ranks make this many tokens
|
||||
rank_token_scales: Optional[torch.Tensor]
|
||||
rank_token_scales: torch.Tensor | None
|
||||
topk: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
config: TestConfig
|
||||
@@ -128,12 +127,12 @@ def make_modular_kernel(
|
||||
dp_size: int,
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
q_dtype: torch.dtype | None,
|
||||
use_fp8_dispatch: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
ht_args: Optional[DeepEPHTArgs] = None
|
||||
ll_args: Optional[DeepEPLLArgs] = None
|
||||
ht_args: DeepEPHTArgs | None = None
|
||||
ll_args: DeepEPLLArgs | None = None
|
||||
|
||||
if low_latency_mode:
|
||||
ll_args = DeepEPLLArgs(
|
||||
@@ -148,16 +147,14 @@ def make_modular_kernel(
|
||||
)
|
||||
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
|
||||
|
||||
a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = (
|
||||
make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=None,
|
||||
deepep_ht_args=ht_args,
|
||||
deepep_ll_args=ll_args,
|
||||
)
|
||||
a2a: DeepEPHTPrepareAndFinalize | DeepEPLLPrepareAndFinalize = make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=None,
|
||||
deepep_ht_args=ht_args,
|
||||
deepep_ll_args=ll_args,
|
||||
)
|
||||
|
||||
num_dispatchers = pgi.world_size // dp_size
|
||||
@@ -184,8 +181,8 @@ def deep_ep_moe_impl(
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
@@ -281,8 +278,8 @@ def torch_moe_impl(
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
using_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
@@ -340,8 +337,8 @@ def _deep_ep_moe(
|
||||
config: TestConfig,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_scale: torch.Tensor | None,
|
||||
w2_scale: torch.Tensor | None,
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user