[Misc][BE] Type coverage for vllm/compilation [3/3] (#31748)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-12 11:24:38 -08:00
committed by GitHub
parent 08e8e99ce7
commit ad8818bb5e
11 changed files with 331 additions and 278 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
def __init__(
self,
quant_key: QuantKey,
):
) -> None:
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
self.silu_and_mul_matcher = MatcherSiluAndMul()
def empty_quant(self, *args, **kwargs):
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@abstractmethod
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs(self) -> list[torch.Tensor]:
scale = self.quant_matcher.inputs()[1]
return [
*self.silu_and_mul_matcher.inputs(), # input
scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]
@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
def replacement(
input: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
)
return at[1]
inputs = [
*self.silu_and_mul_matcher.inputs(), # input
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
inps = self.get_inputs()
pattern(*inps)
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Nvfp4Quant Pattern
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(kNvfp4Quant)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
output_scale = empty_i32(128, 4)
input_ = empty_bf16(5, 64)
scale = empty_fp32(1, 1)
return [result, output_scale, input_, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized(
self.QUANT_OP,
@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
output_scale: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at = auto_functionalized(
self.FUSED_OP,
result=result,
@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
)
return at[1], at[2]
inputs = [
self.empty_quant(5, 32), # result
empty_i32(128, 4), # output_scale
empty_bf16(5, 64), # input
empty_fp32(1, 1), # scale
]
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmPatternMatcherPass):
@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
ActivationQuantPattern,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec
from types import ModuleType
import torch
import torch._inductor.pattern_matcher as pm
@@ -33,15 +34,15 @@ if find_spec("flashinfer"):
try:
import flashinfer.comm as flashinfer_comm
flashinfer_comm = (
flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef]
flashinfer_comm
if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
else None
)
except ImportError:
flashinfer_comm = None
flashinfer_comm = None # type: ignore[assignment]
else:
flashinfer_comm = None
flashinfer_comm = None # type: ignore[assignment]
logger = init_logger(__name__)
@@ -58,13 +59,13 @@ class BasePattern:
class GEMMReduceScatterPattern(BasePattern):
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [mul, mm_weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
mm = torch.ops.aten.mm.default(mul, mm_weight)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm,
@@ -74,7 +75,7 @@ class GEMMReduceScatterPattern(BasePattern):
)
return reduce_scatter
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
@@ -91,17 +92,17 @@ class GEMMReduceScatterPattern(BasePattern):
class AllGatherGEMMPattern(BasePattern):
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [x, weight]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
@@ -111,9 +112,7 @@ class AllGatherGEMMPattern(BasePattern):
return torch.ops.aten.mm.default(all_gather, weight)
def replacement(
x: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
x,
[weight],
@@ -128,7 +127,7 @@ class AllGatherGEMMPattern(BasePattern):
class ScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
@@ -139,7 +138,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [input, mm_weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
mat2: torch.Tensor,
@@ -196,7 +195,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
class AllGatherScaledMMPattern(BasePattern):
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
@@ -211,7 +210,7 @@ class AllGatherScaledMMPattern(BasePattern):
return [x, weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
@@ -258,7 +257,7 @@ class AllGatherScaledMMPattern(BasePattern):
class CutlassScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
@@ -271,7 +270,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
@@ -331,7 +330,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
class AllGatherCutlassScaledMMPattern(BasePattern):
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
@@ -349,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
return [x, weight, scale_a, scale_b, output]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
@@ -400,7 +399,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Enable symmetric memory for the TP process group
@@ -445,7 +444,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
return compile_range.is_single_size() and compile_range.end % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
@@ -512,11 +511,13 @@ if flashinfer_comm is not None:
f"max token num {max_token_num} * hidden size {hidden_size} * "
f"element size {element_size}"
)
device_capability = current_platform.get_device_capability().to_int()
curr_device = current_platform.get_device_capability()
device_capability = curr_device.to_int() if curr_device is not None else None
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, {}
device_capability, # type: ignore[arg-type]
{},
).get(world_size, None)
# Use one shot if no max size is specified
use_oneshot = (
@@ -606,7 +607,7 @@ class FlashInferFusedAllReduceParams:
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
):
) -> None:
self.rank = rank
self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
@@ -615,7 +616,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
return {
"world_rank": self.rank,
"world_size": self.world_size,
@@ -639,26 +640,30 @@ class AllReduceRMSNormPattern(BasePattern):
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(allreduce_output, weight)
return rms, allreduce_output
def replacement(input: torch.Tensor, weight: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
allreduce = auto_functionalized(
@@ -694,27 +699,29 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input, residual, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -739,8 +746,8 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern),
first_return_only(replacement),
first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call]
self.get_inputs(),
pm.fwd_only,
pm_pass,
@@ -761,7 +768,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
@@ -769,25 +776,27 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
@@ -812,7 +821,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
return allreduce[4], allreduce[1]
pm.register_replacement(
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
@@ -830,7 +839,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
@@ -839,20 +848,20 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
def get_inputs(self) -> list[torch.Tensor]:
input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight, scale]
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight, scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant, _ = self.quant_matcher(rms, scale)
@@ -864,7 +873,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
@@ -886,7 +895,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
return allreduce[4], allreduce[2]
pm.register_replacement(
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
@@ -904,31 +913,31 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [input, quant_result, weight, input_global_scale, output_scale]
return [input, quant_result, weight, input_global_scale, output_scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant_out_tuple = auto_functionalized(
@@ -948,7 +957,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
allreduce = auto_functionalized(
@@ -972,7 +981,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
return allreduce[4], allreduce[1], allreduce[5]
pm.register_replacement(
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
@@ -990,33 +999,33 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [
quant_result,
residual,
input,
output_scale,
weight,
input_global_scale,
]
return [
quant_result,
residual,
input,
output_scale,
weight,
input_global_scale,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor,
@@ -1024,7 +1033,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant_out_tuple = auto_functionalized(
@@ -1045,7 +1054,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -1066,12 +1075,12 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
return allreduce[4], allreduce[2], allreduce[5]
pm.register_replacement(
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllReduceFusionPass(VllmPatternMatcherPass):
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.disabled = True
self.tp_size = get_tensor_model_parallel_world_size()
@@ -1122,7 +1131,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
)
self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc]
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=self.max_token_num,
@@ -1145,7 +1154,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@enable_fake_mode
def register_patterns(self):
def register_patterns(self) -> None:
for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
@@ -1198,7 +1207,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
return compile_range.end <= self.max_token_num
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
if self.disabled:
logger.debug("AllReduceFusionPass disabled")
return
@@ -1206,7 +1215,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def __del__(self):
def __del__(self) -> None:
if getattr(self, "disabled", True):
return
if flashinfer_comm is not None:

View File

@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs):
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp32(*args, **kwargs):
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args, **kwargs):
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args, **kwargs):
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
quant: QuantKey
fused_add: bool
def __str__(self):
def __str__(self) -> str:
return (
f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)"
@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
) -> None:
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
fused_key = FusedRMSQuantKey(
fused_add=False,
quant=QuantKey(
@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
# Cannot use methods, as the self argument affects tracing
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
def __init__(
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
) -> None:
key = FusedRMSQuantKey(
fused_add=True,
quant=QuantKey(
@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
symmetric: bool = True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
return result, residual, scale
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
symmetric: bool = True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
def replacement(input: torch.Tensor, weight: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return self.quant_matcher(result_rms)
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
def replacement(input: torch.Tensor, weight: torch.Tensor):
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
epsilon: float,
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
)
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
def uuid(self) -> str:
return self.hash_source(
self,
RMSNormGroupQuantPattern,

View File

@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
P = ParamSpec("P")
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
):
) -> None:
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
)
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args, **kwargs):
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args, **kwargs):
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule):
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass):
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
def _register(self, pm_pass: PatternMatcherPass) -> None:
raise NotImplementedError
@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer: Attention,
dtype: torch.dtype,
symmetric: bool = True,
):
) -> None:
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
at1 = auto_functionalized(
ATTN_OP,
query=q,
@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
):
) -> torch.Tensor:
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention, dtype: torch.dtype):
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Quant, dtype)
def _register(self, pm_pass: PatternMatcherPass):
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
q: torch.Tensor,
k: torch.Tensor,
@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized(
ATTN_OP,
query=q,
@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self):
def uuid(self) -> str:
return VllmInductorPass.hash_source(
self,
AttentionQuantPattern,

View File

@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> Any:
def uuid(self) -> str:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
This should depend on the pass implementation, so that changes to the

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch._higher_order_ops import auto_functionalized
@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
def __init__(self, enabled: bool) -> None:
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args, **kws):
def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
pass
@abstractmethod
def forward_native(self, *args, **kws):
def forward_native(self, *args: Any, **kwargs: Any) -> Any:
pass
def __call__(self, *args, **kws):
return self.forward(*args, **kws)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kwargs)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
def empty_int64(self, *args, **kws):
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)
def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
):
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self):
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
):
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
def inputs(self):
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self._rmsnorm_op(
return self._rmsnorm_op( # type: ignore[no-any-return]
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
)
@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
has_col_major_scales: bool = False,
is_e8m0: bool = False,
match_rocm_aiter: bool = False,
):
) -> None:
if enabled is None:
enabled = QuantFP8.enabled()
@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
) -> tuple[torch.Tensor, torch.Tensor]:
quant_key_group_shape = self.quant_key.scale.group_shape
if quant_key_group_shape == GroupShape.PER_TOKEN:
return self.QUANT_OP(
return self.QUANT_OP( # type: ignore[no-any-return]
x=input,
quant_dtype=self.quant_key.dtype,
scale=scale,
)
else:
return self.QUANT_OP(input, quant_key_group_shape.col)
return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
def forward_custom(
self,
@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
return self.quant_fp8(input, scale) # type: ignore[no-any-return]
def make_scale(self, input: torch.Tensor, transposed: bool = False):
def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None):
def __init__(self, enabled: bool | None = None) -> None:
if enabled is None:
enabled = SiluAndMul.enabled()
super().__init__(enabled)

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import ParamSpec
import torch
import torch._inductor.pattern_matcher as pm
@@ -23,6 +24,8 @@ logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
P = ParamSpec("P")
class QkNormRopePattern:
"""
@@ -72,7 +75,7 @@ class QkNormRopePattern:
use_flashinfer=self.rope_flashinfer,
)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
@@ -92,8 +95,11 @@ class QkNormRopePattern:
]
@staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
def wrap_trace_fn(
trace_fn: Callable[P, fx.GraphModule],
*process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
@@ -103,19 +109,19 @@ class QkNormRopePattern:
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -143,7 +149,7 @@ class QkNormRopePattern:
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Run fused qk_norm_rope op
result = auto_functionalized(
FUSED_QK_ROPE_OP,
@@ -162,7 +168,7 @@ class QkNormRopePattern:
result_qkv = result[1]
# Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass"
@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self):
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, QkNormRopePattern)

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
weight=weight,
@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
match_aiter_quant: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
def replacement(
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result = self.FUSED_OP(
x=input,
residual=residual,
@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=False,
@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result, scale = self.quant_matcher(result_rms)
return result, scale
@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
weight=weight,
@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
match_aiter_quant: bool = True,
symmetric=True,
):
symmetric: bool = True,
) -> None:
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
fused_add=True,
@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
super().__init__(epsilon, key, match_aiter_quant)
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.FUSED_OP(
x=input,
residual=residual,
@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> Any:
def uuid(self) -> str:
fusion_patterns = [
AiterRMSNormDynamicQuantPattern,
AiterFusedAddRMSNormDynamicQuantPattern,
@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload):
def __init__(self, quant_op: OpOverload) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
def register(self, pm_pass: PatternMatcherPass):
def get_inputs(self) -> list[torch.Tensor]:
return [
self.silu_and_mul_matcher.inputs()[0],
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
return at2[0], at2[1]
def replacement(
input: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
return at[0], at[1]
inputs = [
self.silu_and_mul_matcher.inputs()[0],
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
def uuid(self) -> str:
fusion_patterns = [
ActivationQuantPattern,
AiterSiluMulFp8GroupQuantPattern,

View File

@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable, Sequence
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
def get_first_out_wrapper(fn):
def get_first_out_wrapper(
fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def wrapper(*args):
def wrapper(*args: Any) -> torch.Tensor:
return fn(*args)[0]
return wrapper
@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def pattern(
input: torch.Tensor,
arg3_1: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def replacement(
input: torch.Tensor,
arg3_1: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
epsilon: float,
dtype: torch.dtype,
device: str | None,
):
) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight)
quant, _ = self.quant_matcher(rms, scale)
@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self):
def get_inputs(self) -> list[torch.Tensor]:
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
return [residual, mm_1, rms_norm_weights, scale]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Used to clean up redundant views created temporarily
@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Clean up reshape nodes

View File

@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
_TP = old_tp_group
def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
def get_tensor_model_parallel_rank():
def get_tensor_model_parallel_rank() -> int:
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size():
def get_decode_context_model_parallel_world_size() -> int:
"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size
def get_decode_context_model_parallel_rank():
def get_decode_context_model_parallel_rank() -> int:
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group

View File

@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
_ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {}
__all__ = ["RotaryEmbedding"]
def get_rope(