diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index b5fd67c5b..f0ce5b3db 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from typing import Any import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC): def __init__( self, quant_key: QuantKey, - ): + ) -> None: self.quant_key = quant_key self.quant_dtype = quant_key.dtype @@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC): self.silu_and_mul_matcher = MatcherSiluAndMul() - def empty_quant(self, *args, **kwargs): + def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor: kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) @abstractmethod - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: raise NotImplementedError @@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): Fusion for SiluMul+Fp8StaticQuant Pattern """ - def __init__(self): + def __init__(self) -> None: super().__init__(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) - def register(self, pm_pass: PatternMatcherPass): + def get_inputs(self) -> list[torch.Tensor]: + scale = self.quant_matcher.inputs()[1] + return [ + *self.silu_and_mul_matcher.inputs(), # input + scale, + ] + + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, scale: torch.Tensor, - ): + ) -> torch.Tensor: result_silu_mul = self.silu_and_mul_matcher(input) result_quant = self.quant_matcher(result_silu_mul, scale) return result_quant[0] @@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): def replacement( input: torch.Tensor, scale: torch.Tensor, - ): + ) -> torch.Tensor: d = input.shape[-1] // 2 output_shape = input.shape[:-1] + (d,) result = torch.empty( @@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): ) return at[1] - inputs = [ - *self.silu_and_mul_matcher.inputs(), # input - self.quant_matcher.inputs()[1], # scale - ] - pattern(*inputs) + inps = self.get_inputs() + pattern(*inps) - register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) + register_replacement(pattern, replacement, inps, fwd_only, pm_pass) class SiluMulNvfp4QuantPattern(ActivationQuantPattern): @@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): Fusion for SiluMul+Nvfp4Quant Pattern """ - def __init__(self): + def __init__(self) -> None: super().__init__(kNvfp4Quant) - def register(self, pm_pass: PatternMatcherPass): + def get_inputs(self) -> list[torch.Tensor]: + result = self.empty_quant(5, 32) + output_scale = empty_i32(128, 4) + input_ = empty_bf16(5, 64) + scale = empty_fp32(1, 1) + return [result, output_scale, input_, scale] + + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( result: torch.Tensor, output_scale: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: result_silu_mul = self.silu_and_mul_matcher(input) at = auto_functionalized( self.QUANT_OP, @@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): output_scale: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: at = auto_functionalized( self.FUSED_OP, result=result, @@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ) return at[1], at[2] - inputs = [ - self.empty_quant(5, 32), # result - empty_i32(128, 4), # output_scale - empty_bf16(5, 64), # input - empty_fp32(1, 1), # scale - ] - - register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) + register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass) class ActivationQuantFusionPass(VllmPatternMatcherPass): @@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass): """ @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( @@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: torch.fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) - def uuid(self): + def uuid(self) -> str: return VllmInductorPass.hash_source( self, ActivationQuantPattern, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a67d63614..420007131 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from importlib.util import find_spec +from types import ModuleType import torch import torch._inductor.pattern_matcher as pm @@ -33,15 +34,15 @@ if find_spec("flashinfer"): try: import flashinfer.comm as flashinfer_comm - flashinfer_comm = ( + flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef] flashinfer_comm if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") else None ) except ImportError: - flashinfer_comm = None + flashinfer_comm = None # type: ignore[assignment] else: - flashinfer_comm = None + flashinfer_comm = None # type: ignore[assignment] logger = init_logger(__name__) @@ -58,13 +59,13 @@ class BasePattern: class GEMMReduceScatterPattern(BasePattern): - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [mul, mm_weight] - def register(self, pm_pass: PatternMatcherPass): - def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor: mm = torch.ops.aten.mm.default(mul, mm_weight) reduce_scatter = torch.ops.vllm.reduce_scatter.default( mm, @@ -74,7 +75,7 @@ class GEMMReduceScatterPattern(BasePattern): ) return reduce_scatter - def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): + def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor: gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( mul, mm_weight, @@ -91,17 +92,17 @@ class GEMMReduceScatterPattern(BasePattern): class AllGatherGEMMPattern(BasePattern): - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: x = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [x, weight] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( x: torch.Tensor, weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: all_gather = torch.ops.vllm.all_gather.default( x, dim=0, @@ -111,9 +112,7 @@ class AllGatherGEMMPattern(BasePattern): return torch.ops.aten.mm.default(all_gather, weight) - def replacement( - x: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( x, [weight], @@ -128,7 +127,7 @@ class AllGatherGEMMPattern(BasePattern): class ScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) mm_weight = ( torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) @@ -139,7 +138,7 @@ class ScaledMMReduceScatterPattern(BasePattern): scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) return [input, mm_weight, scale_a, scale_b] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, mat2: torch.Tensor, @@ -196,7 +195,7 @@ class ScaledMMReduceScatterPattern(BasePattern): class AllGatherScaledMMPattern(BasePattern): - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) weight = ( torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) @@ -211,7 +210,7 @@ class AllGatherScaledMMPattern(BasePattern): return [x, weight, scale_a, scale_b] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -258,7 +257,7 @@ class AllGatherScaledMMPattern(BasePattern): class CutlassScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) mm_weight = ( torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) @@ -271,7 +270,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern): cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype) return [input, mm_weight, scale_a, scale_b, cutlass_mm_output] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, @@ -331,7 +330,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern): class AllGatherCutlassScaledMMPattern(BasePattern): - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) weight = ( torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) @@ -349,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern): return [x, weight, scale_a, scale_b, output] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -400,7 +399,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern): class AsyncTPPass(VllmPatternMatcherPass): @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) # Enable symmetric memory for the TP process group @@ -445,7 +444,7 @@ class AsyncTPPass(VllmPatternMatcherPass): return compile_range.is_single_size() and compile_range.end % tp_size == 0 @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph): + def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) @@ -512,11 +511,13 @@ if flashinfer_comm is not None: f"max token num {max_token_num} * hidden size {hidden_size} * " f"element size {element_size}" ) - device_capability = current_platform.get_device_capability().to_int() + curr_device = current_platform.get_device_capability() + device_capability = curr_device.to_int() if curr_device is not None else None # Get one shot input size limit for the current world size # for the current device capability max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} + device_capability, # type: ignore[arg-type] + {}, ).get(world_size, None) # Use one shot if no max size is specified use_oneshot = ( @@ -606,7 +607,7 @@ class FlashInferFusedAllReduceParams: world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - ): + ) -> None: self.rank = rank self.world_size = world_size self.use_fp32_lamport = use_fp32_lamport @@ -615,7 +616,7 @@ class FlashInferFusedAllReduceParams: self.fp32_acc = True self.max_token_num = max_token_num - def get_trtllm_fused_allreduce_kwargs(self): + def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]: return { "world_rank": self.rank, "world_size": self.world_size, @@ -639,26 +640,30 @@ class AllReduceRMSNormPattern(BasePattern): dtype: torch.dtype, device: str | None, allreduce_params: FlashInferFusedAllReduceParams, - ): + ) -> None: super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.rmsnorm_matcher = MatcherRMSNorm(epsilon) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: input, weight = self.rmsnorm_matcher.inputs() # input goes through allreduce first, always 16-bit return [input.to(self.dtype), weight] - def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: allreduce_output = tensor_model_parallel_all_reduce(input) rms = self.rmsnorm_matcher(allreduce_output, weight) return rms, allreduce_output - def replacement(input: torch.Tensor, weight: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: residual = torch.zeros_like(input) rms_result = torch.empty_like(input) allreduce = auto_functionalized( @@ -694,27 +699,29 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): dtype: torch.dtype, device: str | None, allreduce_params: FlashInferFusedAllReduceParams, - ): + ) -> None: super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: input, residual, weight = self.rmsnorm_matcher.inputs() # input goes through allreduce first, always 16-bit return [residual, input.to(self.dtype), weight] - def register(self, pm_pass: PatternMatcherPass): - def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: allreduce_output = tensor_model_parallel_all_reduce(input) rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) return rms, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -739,8 +746,8 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] pm.register_replacement( - first_return_only(pattern), - first_return_only(replacement), + first_return_only(pattern), # type: ignore[no-untyped-call] + first_return_only(replacement), # type: ignore[no-untyped-call] self.get_inputs(), pm.fwd_only, pm_pass, @@ -761,7 +768,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): dtype: torch.dtype, device: str | None, allreduce_params: FlashInferFusedAllReduceParams, - ): + ) -> None: super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params @@ -769,25 +776,27 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) - def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input, weight = self.rmsnorm_matcher.inputs() - _, scale = self.quant_matcher.inputs() + def get_inputs(self) -> list[torch.Tensor]: + input, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() - # input goes through allreduce first, always 16-bit - return [input.to(self.dtype), weight, scale] + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight, scale] + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = tensor_model_parallel_all_reduce(input) rms = self.rmsnorm_matcher(all_reduce, weight) quant, _ = self.quant_matcher(rms, scale) return quant, all_reduce - def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: residual = torch.zeros_like(input) result_rms = torch.empty_like(input) result_quant = torch.empty_like(input, dtype=self.quant_dtype) @@ -812,7 +821,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): return allreduce[4], allreduce[1] pm.register_replacement( - pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) @@ -830,7 +839,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): dtype: torch.dtype, device: str | None, allreduce_params: FlashInferFusedAllReduceParams, - ): + ) -> None: super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params @@ -839,20 +848,20 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) - def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input, residual, weight = self.rmsnorm_matcher.inputs() - _, scale = self.quant_matcher.inputs() + def get_inputs(self) -> list[torch.Tensor]: + input, residual, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() - # input goes through allreduce first, always 16-bit - return [residual, input.to(self.dtype), weight, scale] + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight, scale] + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: allreduce_output = tensor_model_parallel_all_reduce(input) rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) quant, _ = self.quant_matcher(rms, scale) @@ -864,7 +873,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -886,7 +895,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): return allreduce[4], allreduce[2] pm.register_replacement( - pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) @@ -904,31 +913,31 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): dtype: torch.dtype, device: str | None, allreduce_params: FlashInferFusedAllReduceParams, - ): + ) -> None: super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.rmsnorm_matcher = MatcherRMSNorm(epsilon) - def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) - quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) - input_global_scale = torch.empty( - [1, 1], device=self.device, dtype=torch.float32 - ) - weight = torch.empty([16], device=self.device, dtype=self.dtype) - output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) + def get_inputs(self) -> list[torch.Tensor]: + input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) + quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) + input_global_scale = torch.empty( + [1, 1], device=self.device, dtype=torch.float32 + ) + weight = torch.empty([16], device=self.device, dtype=self.dtype) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) - return [input, quant_result, weight, input_global_scale, output_scale] + return [input, quant_result, weight, input_global_scale, output_scale] + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: all_reduce = tensor_model_parallel_all_reduce(input) rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( @@ -948,7 +957,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: residual = torch.zeros_like(input) result_rms = torch.empty_like(input) allreduce = auto_functionalized( @@ -972,7 +981,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): return allreduce[4], allreduce[1], allreduce[5] pm.register_replacement( - pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) @@ -990,33 +999,33 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): dtype: torch.dtype, device: str | None, allreduce_params: FlashInferFusedAllReduceParams, - ): + ) -> None: super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + def get_inputs(self) -> list[torch.Tensor]: + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) - residual = torch.empty([16, 16], device=self.device, dtype=self.dtype) - weight = torch.empty([16, 16], device=self.device, dtype=self.dtype) - quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) - input_global_scale = torch.empty( - [1, 1], device=self.device, dtype=torch.float32 - ) - output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) + residual = torch.empty([16, 16], device=self.device, dtype=self.dtype) + weight = torch.empty([16, 16], device=self.device, dtype=self.dtype) + quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) + input_global_scale = torch.empty( + [1, 1], device=self.device, dtype=torch.float32 + ) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) - return [ - quant_result, - residual, - input, - output_scale, - weight, - input_global_scale, - ] + return [ + quant_result, + residual, + input, + output_scale, + weight, + input_global_scale, + ] + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( quant_result: torch.Tensor, residual: torch.Tensor, @@ -1024,7 +1033,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): output_scale: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: allreduce_output = tensor_model_parallel_all_reduce(input) rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( @@ -1045,7 +1054,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): output_scale: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1066,12 +1075,12 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): return allreduce[4], allreduce[2], allreduce[5] pm.register_replacement( - pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) class AllReduceFusionPass(VllmPatternMatcherPass): - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.disabled = True self.tp_size = get_tensor_model_parallel_world_size() @@ -1122,7 +1131,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ) self.ipc_handles, workspace_tensor = ( - flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc] tp_rank=rank, tp_size=self.tp_size, max_token_num=self.max_token_num, @@ -1145,7 +1154,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) @enable_fake_mode - def register_patterns(self): + def register_patterns(self) -> None: for epsilon in [1e-5, 1e-6]: AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, @@ -1198,7 +1207,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): return compile_range.end <= self.max_token_num @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph): + def __call__(self, graph: fx.Graph) -> None: if self.disabled: logger.debug("AllReduceFusionPass disabled") return @@ -1206,7 +1215,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) - def __del__(self): + def __del__(self) -> None: if getattr(self, "disabled", True): return if flashinfer_comm is not None: diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index d12110633..e3c6c2f20 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -def empty_bf16(*args, **kwargs): +def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") -def empty_fp32(*args, **kwargs): +def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") -def empty_i32(*args, **kwargs): +def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") -def empty_i64(*args, **kwargs): +def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor: return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") @@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple): quant: QuantKey fused_add: bool - def __str__(self): + def __str__(self) -> str: return ( f"FusedQuantKey({self.quant}, with" f"{'' if self.fused_add else 'out'} residual)" @@ -121,7 +121,7 @@ class RMSNormQuantPattern: key: FusedRMSQuantKey, has_col_major_scales: bool = False, is_e8m0: bool = False, - ): + ) -> None: self.epsilon = epsilon self.quant_dtype = key.quant.dtype config = get_current_vllm_config() @@ -141,7 +141,9 @@ class RMSNormQuantPattern: class RMSNormStaticQuantPattern(RMSNormQuantPattern): - def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + def __init__( + self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True + ) -> None: fused_key = FusedRMSQuantKey( fused_add=False, quant=QuantKey( @@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ) super().__init__(epsilon, fused_key) - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: # Cannot use methods, as the self argument affects tracing - def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def pattern( + input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: result_rms = self.rmsnorm_matcher(input, weight) return self.quant_matcher(result_rms, scale)[0] - def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) @@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + def __init__( + self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True + ) -> None: key = FusedRMSQuantKey( fused_add=True, quant=QuantKey( @@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, _ = self.quant_matcher(result_rms, scale) @@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): weight: torch.Tensor, residual: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) @@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): epsilon: float, quant_dtype: torch.dtype, group_shape: GroupShape, - symmetric=True, + symmetric: bool = True, has_col_major_scales: bool = False, is_e8m0: bool = False, - ): + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=True, @@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 ) - def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, scale = self.quant_matcher(result_rms) return result, residual, scale def replacement( input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) @@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): epsilon: float, quant_dtype: torch.dtype, group_shape: GroupShape, - symmetric=True, + symmetric: bool = True, has_col_major_scales: bool = False, is_e8m0: bool = False, - ): + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=False, @@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 ) - def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: result_rms = self.rmsnorm_matcher(input, weight) result, scale = self.quant_matcher(result_rms) return result, scale - def replacement(input: torch.Tensor, weight: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) @@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): epsilon: float, quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + symmetric: bool = True, + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=False, @@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: result_rms = self.rmsnorm_matcher(input, weight) # result, scale - return self.quant_matcher(result_rms) + return self.quant_matcher(result_rms) # type: ignore[no-any-return] - def replacement(input: torch.Tensor, weight: torch.Tensor): + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) @@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): epsilon: float, quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + symmetric: bool = True, + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=True, @@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass): - def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: result_rms, residual = self.rmsnorm_matcher(input, weight, residual) result, scale = self.quant_matcher(result_rms) @@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): def replacement( input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # In case we're matching native rms-norm, conversions might be # optimized out. We convert here just to be safe. input = input.to(dtype=self.model_dtype) @@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): """ @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( @@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph): + def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) - def uuid(self) -> Any: + def uuid(self) -> str: return self.hash_source( self, RMSNormGroupQuantPattern, diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 6dcbbd85d..57448aa0b 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable +from typing import Any, ParamSpec import torch import torch._inductor.pattern_matcher as pm @@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) - +P = ParamSpec("P") FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC): layer: Attention, quant_key: QuantKey, dtype: torch.dtype, - ): + ) -> None: self.layer = layer self.layer_name = layer.layer_name self.num_heads = layer.num_heads @@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC): ) self.QUANT_OP = QUANT_OPS[self.quant_key] - def empty(self, *args, **kwargs): + def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor: kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) - def empty_quant(self, *args, **kwargs): + def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor: kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) @staticmethod - def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): - def wrapped(*args, **kwargs): + def wrap_trace_fn( + trace_fn: Callable[P, fx.GraphModule], + *process_fx_fns: Callable[[fx.GraphModule], None], + ) -> Callable[P, fx.GraphModule]: + def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule: gm = trace_fn(*args, **kwargs) for process_fx in process_fx_fns: process_fx(gm) @@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC): return wrapped @staticmethod - def fx_view_to_reshape(gm: torch.fx.GraphModule): + def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None: from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(gm) @staticmethod - def remove_noop_permutes(gm: torch.fx.GraphModule): + def remove_noop_permutes(gm: torch.fx.GraphModule) -> None: for node in gm.graph.nodes: if not is_func(node, torch.ops.aten.permute.default): continue @@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC): node.replace_all_uses_with(node.args[0]) gm.graph.erase_node(node) - def register_if_supported(self, pm_pass: PatternMatcherPass): + def register_if_supported(self, pm_pass: PatternMatcherPass) -> None: if self.layer.impl.fused_output_quant_supported(self.quant_key): self._register(pm_pass) @abstractmethod - def _register(self, pm_pass: PatternMatcherPass): + def _register(self, pm_pass: PatternMatcherPass) -> None: raise NotImplementedError @@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): layer: Attention, dtype: torch.dtype, symmetric: bool = True, - ): + ) -> None: quant_key = QuantKey( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) self.quant_matcher = MatcherQuantFP8(quant_key) - def _register(self, pm_pass: PatternMatcherPass): + def _register(self, pm_pass: PatternMatcherPass) -> None: def pattern( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, scale: torch.Tensor, - ): + ) -> torch.Tensor: at1 = auto_functionalized( ATTN_OP, query=q, @@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): v: torch.Tensor, output_attn: torch.Tensor, scale: torch.Tensor, - ): + ) -> torch.Tensor: # attn output in quant_dtype output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size], @@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. """ - def __init__(self, layer: Attention, dtype: torch.dtype): + def __init__(self, layer: Attention, dtype: torch.dtype) -> None: super().__init__(layer, kNvfp4Quant, dtype) - def _register(self, pm_pass: PatternMatcherPass): + def _register(self, pm_pass: PatternMatcherPass) -> None: def pattern( q: torch.Tensor, k: torch.Tensor, @@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): output_quant: torch.Tensor, output_scale: torch.Tensor, input_scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: at1 = auto_functionalized( ATTN_OP, query=q, @@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): output_quant: torch.Tensor, output_scale: torch.Tensor, input_scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: # attention output in quant_dtype output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size // 2], @@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass): """ @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") @@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass): self.matched_count = self.patterns.apply(graph) logger.debug("Fused quant onto %s attention nodes", self.matched_count) - def uuid(self): + def uuid(self) -> str: return VllmInductorPass.hash_source( self, AttentionQuantPattern, diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 93b2612f2..21723b6d3 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc] This is defined as a convenience and should work in most cases. """ - def uuid(self) -> Any: + def uuid(self) -> str: """ Provide a unique identifier for the pass, used in Inductor code cache. This should depend on the pass implementation, so that changes to the diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 7301aa3e5..eda12180d 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from typing import Any import torch from torch._higher_order_ops import auto_functionalized @@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default class MatcherCustomOp(ABC): - def __init__(self, enabled: bool): + def __init__(self, enabled: bool) -> None: config = get_current_vllm_config() self.model_dtype = config.model_config.dtype if config.model_config else None self.device = config.device_config.device if config.device_config else None @@ -56,24 +57,24 @@ class MatcherCustomOp(ABC): self.forward = self.forward_custom if enabled else self.forward_native @abstractmethod - def forward_custom(self, *args, **kws): + def forward_custom(self, *args: Any, **kwargs: Any) -> Any: pass @abstractmethod - def forward_native(self, *args, **kws): + def forward_native(self, *args: Any, **kwargs: Any) -> Any: pass - def __call__(self, *args, **kws): - return self.forward(*args, **kws) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.forward(*args, **kwargs) - def empty(self, *args, **kws): - return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) + def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs) - def empty_int64(self, *args, **kws): - return torch.empty(*args, dtype=torch.int64, device=self.device, **kws) + def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs) - def empty_f32(self, *args, **kws): - return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) + def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor: + return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs) def inputs(self) -> list[torch.Tensor]: """Utility for inputs to the pattern""" @@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp): epsilon: float, enabled: bool | None = None, match_rocm_aiter: bool = False, - ): + ) -> None: if enabled is None: enabled = RMSNorm.enabled() @@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp): if match_rocm_aiter: self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op() - def inputs(self): + def inputs(self) -> list[torch.Tensor]: input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) weight = self.empty(16) return [input, weight] @@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): epsilon: float, enabled: bool | None = None, match_rocm_aiter: bool = False, - ): + ) -> None: if enabled is None: enabled = RMSNorm.enabled() @@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): if match_rocm_aiter: self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() - def inputs(self): + def inputs(self) -> list[torch.Tensor]: input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) weight = self.empty(16) residual = self.empty(5, 16) @@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - return self._rmsnorm_op( + return self._rmsnorm_op( # type: ignore[no-any-return] x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon ) @@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp): has_col_major_scales: bool = False, is_e8m0: bool = False, match_rocm_aiter: bool = False, - ): + ) -> None: if enabled is None: enabled = QuantFP8.enabled() @@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp): ) -> tuple[torch.Tensor, torch.Tensor]: quant_key_group_shape = self.quant_key.scale.group_shape if quant_key_group_shape == GroupShape.PER_TOKEN: - return self.QUANT_OP( + return self.QUANT_OP( # type: ignore[no-any-return] x=input, quant_dtype=self.quant_key.dtype, scale=scale, ) else: - return self.QUANT_OP(input, quant_key_group_shape.col) + return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return] def forward_custom( self, @@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp): input: torch.Tensor, scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - return self.quant_fp8(input, scale) + return self.quant_fp8(input, scale) # type: ignore[no-any-return] - def make_scale(self, input: torch.Tensor, transposed: bool = False): + def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor: normalized_group_shape = _normalize_quant_group_shape( input, self.quant_key.scale.group_shape ) @@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp): class MatcherSiluAndMul(MatcherCustomOp): - def __init__(self, enabled: bool | None = None): + def __init__(self, enabled: bool | None = None) -> None: if enabled is None: enabled = SiluAndMul.enabled() super().__init__(enabled) diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/qk_norm_rope_fusion.py index 794cd8e3f..bc95b7238 100644 --- a/vllm/compilation/qk_norm_rope_fusion.py +++ b/vllm/compilation/qk_norm_rope_fusion.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import ParamSpec import torch import torch._inductor.pattern_matcher as pm @@ -23,6 +24,8 @@ logger = init_logger(__name__) FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default +P = ParamSpec("P") + class QkNormRopePattern: """ @@ -72,7 +75,7 @@ class QkNormRopePattern: use_flashinfer=self.rope_flashinfer, ) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: # Sample inputs to help pattern tracing T = 5 qkv = empty_bf16(T, self.q_size + 2 * self.kv_size) @@ -92,8 +95,11 @@ class QkNormRopePattern: ] @staticmethod - def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): - def wrapped(*args, **kwargs): + def wrap_trace_fn( + trace_fn: Callable[P, fx.GraphModule], + *process_fx_fns: Callable[[fx.GraphModule], None], + ) -> Callable[P, fx.GraphModule]: + def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule: gm = trace_fn(*args, **kwargs) for process_fx in process_fx_fns: process_fx(gm) @@ -103,19 +109,19 @@ class QkNormRopePattern: return wrapped @staticmethod - def fx_view_to_reshape(gm: torch.fx.GraphModule): + def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None: from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(gm) - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( qkv: torch.Tensor, positions: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos_sin_cache: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # split qkv -> q,k,v q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -143,7 +149,7 @@ class QkNormRopePattern: q_weight: torch.Tensor, k_weight: torch.Tensor, cos_sin_cache: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Run fused qk_norm_rope op result = auto_functionalized( FUSED_QK_ROPE_OP, @@ -162,7 +168,7 @@ class QkNormRopePattern: result_qkv = result[1] # Split back to q,k,v and return - return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return] # NOTE: use fx_view_to_reshape to unify view/reshape to simplify # pattern and increase matching opportunities @@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass): """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.""" @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="qk_norm_rope_fusion_pass" @@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass): self.matched_count = self.patterns.apply(graph) logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count) - def uuid(self): + def uuid(self) -> str: return VllmInductorPass.hash_source(self, QkNormRopePattern) diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index f66bb76b9..7a300cf50 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import torch import torch._inductor.pattern_matcher as pm @@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): quant_dtype: torch.dtype, match_aiter_quant: bool = True, group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + symmetric: bool = True, + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=False, @@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): super().__init__(epsilon, key, match_aiter_quant) - def register(self, pm_pass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: result_rms = self.rmsnorm_matcher(input, weight) result, scale = self.quant_matcher(result_rms) return result, scale @@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): def replacement( input: torch.Tensor, weight: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: result = self.FUSED_OP( x=input, weight=weight, @@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): quant_dtype: torch.dtype, match_aiter_quant: bool = True, group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True, - ): + symmetric: bool = True, + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=True, @@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): super().__init__(epsilon, key, match_aiter_quant) - def register(self, pm_pass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) result, scale = self.quant_matcher(result_rms) @@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern): def replacement( input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: result = self.FUSED_OP( x=input, residual=residual, @@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape, match_aiter_quant: bool = True, - symmetric=True, - ): + symmetric: bool = True, + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=False, @@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): super().__init__(epsilon, key, match_aiter_quant) - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: result_rms = self.rmsnorm_matcher(input, weight) result, scale = self.quant_matcher(result_rms) return result, scale @@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): def replacement( input: torch.Tensor, weight: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: at = self.FUSED_OP( x=input, weight=weight, @@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape, match_aiter_quant: bool = True, - symmetric=True, - ): + symmetric: bool = True, + ) -> None: scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey( fused_add=True, @@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): super().__init__(epsilon, key, match_aiter_quant) - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) result, scale = self.quant_matcher(result_rms) @@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: at = self.FUSED_OP( x=input, residual=residual, @@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): """ @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( @@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph): + def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) - def uuid(self) -> Any: + def uuid(self) -> str: fusion_patterns = [ AiterRMSNormDynamicQuantPattern, AiterFusedAddRMSNormDynamicQuantPattern, @@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() - def __init__(self, quant_op: OpOverload): + def __init__(self, quant_op: OpOverload) -> None: self.silu_and_mul_matcher = MatcherSiluAndMul() self.quant_op = quant_op - def register(self, pm_pass: PatternMatcherPass): + def get_inputs(self) -> list[torch.Tensor]: + return [ + self.silu_and_mul_matcher.inputs()[0], + ] + + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: at1 = self.silu_and_mul_matcher(input) at2 = self.quant_op(at1, 128) return at2[0], at2[1] def replacement( input: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) return at[0], at[1] - inputs = [ - self.silu_and_mul_matcher.inputs()[0], - ] - - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): @@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( @@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: torch.fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) - def uuid(self): + def uuid(self) -> str: fusion_patterns = [ ActivationQuantPattern, AiterSiluMulFp8GroupQuantPattern, diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index bf81a62f2..34ff2ab47 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools +from collections.abc import Callable, Sequence +from typing import Any import torch import torch._inductor.pattern_matcher as pm @@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) -def get_first_out_wrapper(fn): +def get_first_out_wrapper( + fn: Callable[..., Sequence[torch.Tensor]], +) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def wrapper(*args): + def wrapper(*args: Any) -> torch.Tensor: return fn(*args)[0] return wrapper @@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherRMSNorm(epsilon) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) @@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def pattern( input: torch.Tensor, arg3_1: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(input) rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) @@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def replacement( input: torch.Tensor, arg3_1: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(input) rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) @@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None: super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): rms_norm_weights, ] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( residual: torch.Tensor, mm_1: torch.Tensor, @@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): epsilon: float, dtype: torch.dtype, device: str | None, - ): + ) -> None: super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) return [input, weight, scale] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(input) rms = self.rmsnorm_matcher(all_reduce, weight) quant, _ = self.quant_matcher(rms, scale) @@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(input) rms = self.rmsnorm_matcher(reduce_scatter, weight) quant, _ = self.quant_matcher(rms, scale) @@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None: super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) - def get_inputs(self): + def get_inputs(self) -> list[torch.Tensor]: mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): return [residual, mm_1, rms_norm_weights, scale] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( residual: torch.Tensor, mm_1: torch.Tensor, @@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): """ @enable_fake_mode - def __init__(self, config: VllmConfig): + def __init__(self, config: VllmConfig) -> None: super().__init__(config) # Used to clean up redundant views created temporarily @@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph): + def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) logger.debug("Replaced %s patterns", self.matched_count) # Clean up reshape nodes diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 840273348..c0f330408 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator): _TP = old_tp_group -def get_tensor_model_parallel_world_size(): +def get_tensor_model_parallel_world_size() -> int: """Return world size for the tensor model parallel group.""" return get_tp_group().world_size -def get_tensor_model_parallel_rank(): +def get_tensor_model_parallel_rank() -> int: """Return my rank for the tensor model parallel group.""" return get_tp_group().rank_in_group -def get_decode_context_model_parallel_world_size(): +def get_decode_context_model_parallel_world_size() -> int: """Return world size for the decode context model parallel group.""" return get_dcp_group().world_size -def get_decode_context_model_parallel_rank(): +def get_decode_context_model_parallel_rank() -> int: """Return my rank for the decode context model parallel group.""" return get_dcp_group().rank_in_group diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 452b87ea4..44a510228 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding from .xdrope import XDRotaryEmbedding from .yarn_scaling_rope import YaRNScalingRotaryEmbedding -_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} +_ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {} + +__all__ = ["RotaryEmbedding"] def get_rope(