Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -27,7 +26,7 @@ class _RMSNormAndQuantOpHelper:
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
quant_op: torch._ops.OpOverload | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
@@ -110,7 +109,7 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
quant_op: torch._ops.OpOverload | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
||||
@@ -483,7 +482,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
).register(self.patterns)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
|
||||
def is_applicable_for_shape(self, shape: int | None) -> bool:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return shape is not None and shape % tp_size == 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user