[Misc][BE] Type coverage for vllm/compilation [3/3] (#31748)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
@@ -52,7 +53,7 @@ class ActivationQuantPattern(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
quant_key: QuantKey,
|
||||
):
|
||||
) -> None:
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
|
||||
@@ -68,12 +69,12 @@ class ActivationQuantPattern(ABC):
|
||||
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -82,15 +83,22 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
Fusion for SiluMul+Fp8StaticQuant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(kFp8StaticTensorSym)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
scale = self.quant_matcher.inputs()[1]
|
||||
return [
|
||||
*self.silu_and_mul_matcher.inputs(), # input
|
||||
scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
result_quant = self.quant_matcher(result_silu_mul, scale)
|
||||
return result_quant[0]
|
||||
@@ -98,7 +106,7 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
d = input.shape[-1] // 2
|
||||
output_shape = input.shape[:-1] + (d,)
|
||||
result = torch.empty(
|
||||
@@ -109,13 +117,10 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
)
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
*self.silu_and_mul_matcher.inputs(), # input
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
inps = self.get_inputs()
|
||||
pattern(*inps)
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
@@ -123,16 +128,23 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
Fusion for SiluMul+Nvfp4Quant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(kNvfp4Quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
result = self.empty_quant(5, 32)
|
||||
output_scale = empty_i32(128, 4)
|
||||
input_ = empty_bf16(5, 64)
|
||||
scale = empty_fp32(1, 1)
|
||||
return [result, output_scale, input_, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
at = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
@@ -148,7 +160,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
output_scale: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@@ -158,14 +170,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
self.empty_quant(5, 32), # result
|
||||
empty_i32(128, 4), # output_scale
|
||||
empty_bf16(5, 64), # input
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
@@ -179,7 +184,7 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
@@ -196,11 +201,11 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
ActivationQuantPattern,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from importlib.util import find_spec
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -33,15 +34,15 @@ if find_spec("flashinfer"):
|
||||
try:
|
||||
import flashinfer.comm as flashinfer_comm
|
||||
|
||||
flashinfer_comm = (
|
||||
flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef]
|
||||
flashinfer_comm
|
||||
if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
|
||||
else None
|
||||
)
|
||||
except ImportError:
|
||||
flashinfer_comm = None
|
||||
flashinfer_comm = None # type: ignore[assignment]
|
||||
else:
|
||||
flashinfer_comm = None
|
||||
flashinfer_comm = None # type: ignore[assignment]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -58,13 +59,13 @@ class BasePattern:
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [mul, mm_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
mm = torch.ops.aten.mm.default(mul, mm_weight)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm,
|
||||
@@ -74,7 +75,7 @@ class GEMMReduceScatterPattern(BasePattern):
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
@@ -91,17 +92,17 @@ class GEMMReduceScatterPattern(BasePattern):
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
@@ -111,9 +112,7 @@ class AllGatherGEMMPattern(BasePattern):
|
||||
|
||||
return torch.ops.aten.mm.default(all_gather, weight)
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
|
||||
x,
|
||||
[weight],
|
||||
@@ -128,7 +127,7 @@ class AllGatherGEMMPattern(BasePattern):
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
@@ -139,7 +138,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
@@ -196,7 +195,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
@@ -211,7 +210,7 @@ class AllGatherScaledMMPattern(BasePattern):
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -258,7 +257,7 @@ class AllGatherScaledMMPattern(BasePattern):
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
@@ -271,7 +270,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -331,7 +330,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
@@ -349,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -400,7 +399,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Enable symmetric memory for the TP process group
|
||||
@@ -445,7 +444,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
return compile_range.is_single_size() and compile_range.end % tp_size == 0
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
@@ -512,11 +511,13 @@ if flashinfer_comm is not None:
|
||||
f"max token num {max_token_num} * hidden size {hidden_size} * "
|
||||
f"element size {element_size}"
|
||||
)
|
||||
device_capability = current_platform.get_device_capability().to_int()
|
||||
curr_device = current_platform.get_device_capability()
|
||||
device_capability = curr_device.to_int() if curr_device is not None else None
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, {}
|
||||
device_capability, # type: ignore[arg-type]
|
||||
{},
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size is specified
|
||||
use_oneshot = (
|
||||
@@ -606,7 +607,7 @@ class FlashInferFusedAllReduceParams:
|
||||
world_size: int,
|
||||
use_fp32_lamport: bool = False,
|
||||
max_token_num: int = 1024,
|
||||
):
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.use_fp32_lamport = use_fp32_lamport
|
||||
@@ -615,7 +616,7 @@ class FlashInferFusedAllReduceParams:
|
||||
self.fp32_acc = True
|
||||
self.max_token_num = max_token_num
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self):
|
||||
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
|
||||
return {
|
||||
"world_rank": self.rank,
|
||||
"world_size": self.world_size,
|
||||
@@ -639,26 +640,30 @@ class AllReduceRMSNormPattern(BasePattern):
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(allreduce_output, weight)
|
||||
|
||||
return rms, allreduce_output
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
rms_result = torch.empty_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
@@ -694,27 +699,29 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
return rms, residual
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -739,8 +746,8 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||
|
||||
pm.register_replacement(
|
||||
first_return_only(pattern),
|
||||
first_return_only(replacement),
|
||||
first_return_only(pattern), # type: ignore[no-untyped-call]
|
||||
first_return_only(replacement), # type: ignore[no-untyped-call]
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
@@ -761,7 +768,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
@@ -769,25 +776,27 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight, scale]
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
@@ -812,7 +821,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
return allreduce[4], allreduce[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
@@ -830,7 +839,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
@@ -839,20 +848,20 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight, scale]
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
@@ -864,7 +873,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
@@ -886,7 +895,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
return allreduce[4], allreduce[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
@@ -904,31 +913,31 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
)
|
||||
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
)
|
||||
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
|
||||
return [input, quant_result, weight, input_global_scale, output_scale]
|
||||
return [input, quant_result, weight, input_global_scale, output_scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
@@ -948,7 +957,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
@@ -972,7 +981,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
return allreduce[4], allreduce[1], allreduce[5]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
@@ -990,33 +999,33 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
output_scale,
|
||||
weight,
|
||||
input_global_scale,
|
||||
]
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
output_scale,
|
||||
weight,
|
||||
input_global_scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
@@ -1024,7 +1033,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
output_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
@@ -1045,7 +1054,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
output_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -1066,12 +1075,12 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
return allreduce[4], allreduce[2], allreduce[5]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.disabled = True
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -1122,7 +1131,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
)
|
||||
|
||||
self.ipc_handles, workspace_tensor = (
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc]
|
||||
tp_rank=rank,
|
||||
tp_size=self.tp_size,
|
||||
max_token_num=self.max_token_num,
|
||||
@@ -1145,7 +1154,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@enable_fake_mode
|
||||
def register_patterns(self):
|
||||
def register_patterns(self) -> None:
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
AllReduceFusedRMSNormStaticQuantFP8Pattern(
|
||||
epsilon,
|
||||
@@ -1198,7 +1207,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
return compile_range.end <= self.max_token_num
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
if self.disabled:
|
||||
logger.debug("AllReduceFusionPass disabled")
|
||||
return
|
||||
@@ -1206,7 +1215,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, "disabled", True):
|
||||
return
|
||||
if flashinfer_comm is not None:
|
||||
|
||||
@@ -38,19 +38,19 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
|
||||
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args, **kwargs):
|
||||
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args, **kwargs):
|
||||
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
def empty_i64(*args, **kwargs):
|
||||
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ class FusedRMSQuantKey(NamedTuple):
|
||||
quant: QuantKey
|
||||
fused_add: bool
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"FusedQuantKey({self.quant}, with"
|
||||
f"{'' if self.fused_add else 'out'} residual)"
|
||||
@@ -121,7 +121,7 @@ class RMSNormQuantPattern:
|
||||
key: FusedRMSQuantKey,
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
config = get_current_vllm_config()
|
||||
@@ -141,7 +141,9 @@ class RMSNormQuantPattern:
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
def __init__(
|
||||
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
|
||||
) -> None:
|
||||
fused_key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(
|
||||
@@ -150,13 +152,17 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
)
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
@@ -187,7 +193,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
def __init__(
|
||||
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
|
||||
) -> None:
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(
|
||||
@@ -196,13 +204,13 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, _ = self.quant_matcher(result_rms, scale)
|
||||
|
||||
@@ -213,7 +221,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
@@ -253,10 +261,10 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
symmetric=True,
|
||||
symmetric: bool = True,
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
@@ -269,15 +277,17 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
|
||||
)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
@@ -315,10 +325,10 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
symmetric=True,
|
||||
symmetric: bool = True,
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
@@ -329,13 +339,17 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
|
||||
)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
@@ -375,8 +389,8 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
@@ -384,13 +398,17 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
# result, scale
|
||||
return self.quant_matcher(result_rms)
|
||||
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
@@ -426,8 +444,8 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
@@ -435,8 +453,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
@@ -444,7 +464,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
@@ -481,7 +501,7 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
@@ -533,11 +553,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
def uuid(self) -> str:
|
||||
return self.hash_source(
|
||||
self,
|
||||
RMSNormGroupQuantPattern,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -28,7 +29,7 @@ from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
@@ -47,7 +48,7 @@ class AttentionQuantPattern(ABC):
|
||||
layer: Attention,
|
||||
quant_key: QuantKey,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
) -> None:
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
@@ -61,17 +62,20 @@ class AttentionQuantPattern(ABC):
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args, **kwargs):
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
def wrap_trace_fn(
|
||||
trace_fn: Callable[P, fx.GraphModule],
|
||||
*process_fx_fns: Callable[[fx.GraphModule], None],
|
||||
) -> Callable[P, fx.GraphModule]:
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
@@ -81,13 +85,13 @@ class AttentionQuantPattern(ABC):
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
@staticmethod
|
||||
def remove_noop_permutes(gm: torch.fx.GraphModule):
|
||||
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
|
||||
for node in gm.graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.permute.default):
|
||||
continue
|
||||
@@ -100,12 +104,12 @@ class AttentionQuantPattern(ABC):
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
self._register(pm_pass)
|
||||
|
||||
@abstractmethod
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def _register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -124,21 +128,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
layer: Attention,
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
self.quant_matcher = MatcherQuantFP8(quant_key)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def _register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
@@ -161,7 +165,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size],
|
||||
@@ -212,10 +216,10 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype):
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
|
||||
super().__init__(layer, kNvfp4Quant, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def _register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
@@ -224,7 +228,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
@@ -256,7 +260,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
@@ -318,7 +322,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
|
||||
@@ -350,7 +354,7 @@ class AttnFusionPass(VllmPatternMatcherPass):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
AttentionQuantPattern,
|
||||
|
||||
@@ -68,7 +68,7 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> Any:
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
@@ -47,7 +48,7 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
|
||||
class MatcherCustomOp(ABC):
|
||||
def __init__(self, enabled: bool):
|
||||
def __init__(self, enabled: bool) -> None:
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
@@ -56,24 +57,24 @@ class MatcherCustomOp(ABC):
|
||||
self.forward = self.forward_custom if enabled else self.forward_native
|
||||
|
||||
@abstractmethod
|
||||
def forward_custom(self, *args, **kws):
|
||||
def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_native(self, *args, **kws):
|
||||
def forward_native(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kws):
|
||||
return self.forward(*args, **kws)
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def empty(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
|
||||
|
||||
def empty_int64(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)
|
||||
def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
|
||||
|
||||
def empty_f32(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
|
||||
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
"""Utility for inputs to the pattern"""
|
||||
@@ -157,7 +158,7 @@ class MatcherRMSNorm(MatcherCustomOp):
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
@@ -169,7 +170,7 @@ class MatcherRMSNorm(MatcherCustomOp):
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
|
||||
|
||||
def inputs(self):
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [input, weight]
|
||||
@@ -220,7 +221,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
@@ -233,7 +234,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
|
||||
|
||||
def inputs(self):
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(5, 16)
|
||||
@@ -245,7 +246,7 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._rmsnorm_op(
|
||||
return self._rmsnorm_op( # type: ignore[no-any-return]
|
||||
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
|
||||
)
|
||||
|
||||
@@ -287,7 +288,7 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
match_rocm_aiter: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = QuantFP8.enabled()
|
||||
|
||||
@@ -340,13 +341,13 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
quant_key_group_shape = self.quant_key.scale.group_shape
|
||||
if quant_key_group_shape == GroupShape.PER_TOKEN:
|
||||
return self.QUANT_OP(
|
||||
return self.QUANT_OP( # type: ignore[no-any-return]
|
||||
x=input,
|
||||
quant_dtype=self.quant_key.dtype,
|
||||
scale=scale,
|
||||
)
|
||||
else:
|
||||
return self.QUANT_OP(input, quant_key_group_shape.col)
|
||||
return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
@@ -400,9 +401,9 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.quant_fp8(input, scale)
|
||||
return self.quant_fp8(input, scale) # type: ignore[no-any-return]
|
||||
|
||||
def make_scale(self, input: torch.Tensor, transposed: bool = False):
|
||||
def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
|
||||
normalized_group_shape = _normalize_quant_group_shape(
|
||||
input, self.quant_key.scale.group_shape
|
||||
)
|
||||
@@ -427,7 +428,7 @@ class MatcherQuantFP8(MatcherCustomOp):
|
||||
|
||||
|
||||
class MatcherSiluAndMul(MatcherCustomOp):
|
||||
def __init__(self, enabled: bool | None = None):
|
||||
def __init__(self, enabled: bool | None = None) -> None:
|
||||
if enabled is None:
|
||||
enabled = SiluAndMul.enabled()
|
||||
super().__init__(enabled)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -23,6 +24,8 @@ logger = init_logger(__name__)
|
||||
|
||||
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class QkNormRopePattern:
|
||||
"""
|
||||
@@ -72,7 +75,7 @@ class QkNormRopePattern:
|
||||
use_flashinfer=self.rope_flashinfer,
|
||||
)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
# Sample inputs to help pattern tracing
|
||||
T = 5
|
||||
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
|
||||
@@ -92,8 +95,11 @@ class QkNormRopePattern:
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
def wrap_trace_fn(
|
||||
trace_fn: Callable[P, fx.GraphModule],
|
||||
*process_fx_fns: Callable[[fx.GraphModule], None],
|
||||
) -> Callable[P, fx.GraphModule]:
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
@@ -103,19 +109,19 @@ class QkNormRopePattern:
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# split qkv -> q,k,v
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
@@ -143,7 +149,7 @@ class QkNormRopePattern:
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Run fused qk_norm_rope op
|
||||
result = auto_functionalized(
|
||||
FUSED_QK_ROPE_OP,
|
||||
@@ -162,7 +168,7 @@ class QkNormRopePattern:
|
||||
result_qkv = result[1]
|
||||
|
||||
# Split back to q,k,v and return
|
||||
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
|
||||
|
||||
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
|
||||
# pattern and increase matching opportunities
|
||||
@@ -182,7 +188,7 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
|
||||
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="qk_norm_rope_fusion_pass"
|
||||
@@ -234,5 +240,5 @@ class QKNormRoPEFusionPass(VllmPatternMatcherPass):
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(self, QkNormRopePattern)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -65,8 +64,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
match_aiter_quant: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
@@ -75,11 +74,11 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
@@ -87,7 +86,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result = self.FUSED_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
@@ -117,8 +116,8 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
match_aiter_quant: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
@@ -127,12 +126,12 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
@@ -140,7 +139,7 @@ class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result = self.FUSED_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
@@ -174,8 +173,8 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
match_aiter_quant: bool = True,
|
||||
symmetric=True,
|
||||
):
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
@@ -184,11 +183,11 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
@@ -196,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at = self.FUSED_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
@@ -225,8 +224,8 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
match_aiter_quant: bool = True,
|
||||
symmetric=True,
|
||||
):
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
@@ -235,12 +234,12 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
@@ -250,7 +249,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
at = self.FUSED_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
@@ -275,7 +274,7 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
@@ -311,11 +310,11 @@ class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
@@ -333,29 +332,32 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||
|
||||
def __init__(self, quant_op: OpOverload):
|
||||
def __init__(self, quant_op: OpOverload) -> None:
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
return [
|
||||
self.silu_and_mul_matcher.inputs()[0],
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = self.silu_and_mul_matcher(input)
|
||||
at2 = self.quant_op(at1, 128)
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||
return at[0], at[1]
|
||||
|
||||
inputs = [
|
||||
self.silu_and_mul_matcher.inputs()[0],
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
@@ -374,7 +376,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
@@ -387,11 +389,11 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
ActivationQuantPattern,
|
||||
AiterSiluMulFp8GroupQuantPattern,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@@ -26,9 +28,11 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_first_out_wrapper(fn):
|
||||
def get_first_out_wrapper(
|
||||
fn: Callable[..., Sequence[torch.Tensor]],
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args):
|
||||
def wrapper(*args: Any) -> torch.Tensor:
|
||||
return fn(*args)[0]
|
||||
|
||||
return wrapper
|
||||
@@ -68,7 +72,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
@@ -78,7 +82,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||
|
||||
@@ -87,7 +91,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||
@@ -100,11 +104,11 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
@@ -116,7 +120,7 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
@@ -163,23 +167,23 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
@@ -189,7 +193,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
@@ -203,12 +207,12 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
@@ -216,7 +220,7 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
return [residual, mm_1, rms_norm_weights, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
@@ -302,7 +306,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Used to clean up redundant views created temporarily
|
||||
@@ -357,7 +361,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
# Clean up reshape nodes
|
||||
|
||||
@@ -1529,22 +1529,22 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
|
||||
_TP = old_tp_group
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
def get_tensor_model_parallel_world_size() -> int:
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return get_tp_group().world_size
|
||||
|
||||
|
||||
def get_tensor_model_parallel_rank():
|
||||
def get_tensor_model_parallel_rank() -> int:
|
||||
"""Return my rank for the tensor model parallel group."""
|
||||
return get_tp_group().rank_in_group
|
||||
|
||||
|
||||
def get_decode_context_model_parallel_world_size():
|
||||
def get_decode_context_model_parallel_world_size() -> int:
|
||||
"""Return world size for the decode context model parallel group."""
|
||||
return get_dcp_group().world_size
|
||||
|
||||
|
||||
def get_decode_context_model_parallel_rank():
|
||||
def get_decode_context_model_parallel_rank() -> int:
|
||||
"""Return my rank for the decode context model parallel group."""
|
||||
return get_dcp_group().rank_in_group
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
|
||||
from .xdrope import XDRotaryEmbedding
|
||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
|
||||
|
||||
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
|
||||
_ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {}
|
||||
|
||||
__all__ = ["RotaryEmbedding"]
|
||||
|
||||
|
||||
def get_rope(
|
||||
|
||||
Reference in New Issue
Block a user