[torch.compile] Reorganize vllm/compilation and tests/compile (0/N for vLLM IR) (#33731)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: ProExpertProg <luka.govedic@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2026-02-06 07:19:49 -05:00
committed by GitHub
parent f79d9dce16
commit ac32e66cf9
47 changed files with 717 additions and 651 deletions

View File

@@ -22,11 +22,6 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._logging._internal import trace_structured
import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
inductor_partition_rule_context,
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import Range, hash_factors
@@ -44,8 +39,12 @@ from .compiler_interface import (
is_compile_cache_enabled,
)
from .counter import compilation_counter
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
from .partition_rules import (
inductor_partition_rule_context,
should_split,
)
from .passes.inductor_pass import InductorPass, pass_context
from .passes.pass_manager import PostGradPassManager
logger = init_logger(__name__)

View File

View File

@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)

View File

@@ -8,7 +8,6 @@ import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
@@ -24,12 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
try:
@@ -45,406 +46,6 @@ logger = init_logger(__name__)
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [mul, mm_weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
mm = torch.ops.aten.mm.default(mul, mm_weight)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
"avg",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherGEMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [x, weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return torch.ops.aten.mm.default(all_gather, weight)
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
x,
[weight],
gather_dim=0,
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class ScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [input, mm_weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
scaled_mm = torch.ops.aten._scaled_mm.default(
input,
mat2=mat2,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
scaled_mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [x, weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
return torch.ops.aten._scaled_mm.default(
all_gather,
mat2=weight,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class CutlassScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=cutlass_mm_output,
a=input,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
cutlass_scaled_mm[1],
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherCutlassScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
s2 = weight.shape[1]
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
return [x, weight, scale_a, scale_b, output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=output,
a=all_gather,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
return cutlass_scaled_mm[1]
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Enable symmetric memory for the TP process group
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="async_tp_pass"
)
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if self.model_dtype == torch.bfloat16:
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size()
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
@@ -623,6 +224,15 @@ class FlashInferFusedAllReduceParams:
}
# TODO(luka): unify
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class AllReduceRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)

View File

@@ -22,11 +22,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .fx_utils import is_func
from .inductor_pass import enable_fake_mode
from ..fx_utils import is_func
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
logger = init_logger(__name__)
P = ParamSpec("P")

View File

@@ -0,0 +1,423 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
class GEMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [mul, mm_weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
mm = torch.ops.aten.mm.default(mul, mm_weight)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
"avg",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherGEMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [x, weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return torch.ops.aten.mm.default(all_gather, weight)
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
x,
[weight],
gather_dim=0,
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class ScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [input, mm_weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
scaled_mm = torch.ops.aten._scaled_mm.default(
input,
mat2=mat2,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
scaled_mm,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
return [x, weight, scale_a, scale_b]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
return torch.ops.aten._scaled_mm.default(
all_gather,
mat2=weight,
scale_a=scale_a,
scale_b=scale_b,
bias=None,
scale_result=None,
out_dtype=self.dtype,
)
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class CutlassScaledMMReduceScatterPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
mm_weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=cutlass_mm_output,
a=input,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
cutlass_scaled_mm[1],
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
return reduce_scatter
def replacement(
input: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)
return gemm_rs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AllGatherCutlassScaledMMPattern(BasePattern):
def get_inputs(self) -> list[torch.Tensor]:
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
weight = (
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
s1 = x.shape[0] * self.tp_size
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
s2 = weight.shape[1]
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
return [x, weight, scale_a, scale_b, output]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
all_gather = torch.ops.vllm.all_gather.default(
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
)
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.cutlass_scaled_mm.default,
out=output,
a=all_gather,
b=weight,
a_scales=scale_a,
b_scales=scale_b,
bias=None,
)
return cutlass_scaled_mm[1]
def replacement(
x: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
output: torch.Tensor,
) -> torch.Tensor:
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
x,
[weight],
scale_a,
[scale_b],
gather_dim=0,
biases=[None],
result_scales=[None],
out_dtypes=[self.dtype],
use_fast_accum=[False],
group_name=self.tp.device_group.group_name,
)
return mm_outputs
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
# Enable symmetric memory for the TP process group
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="async_tp_pass"
)
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
# These fusions are enabled only for bfloat16 models because
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
# only supports bfloat16 as the output dtype.
if self.model_dtype == torch.bfloat16:
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
self.patterns
)
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
self.patterns
)
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size()
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)

View File

@@ -15,10 +15,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from .fusion import empty_bf16, empty_fp32, empty_i64
from .inductor_pass import enable_fake_mode
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
logger = init_logger(__name__)

View File

@@ -25,13 +25,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()

View File

@@ -9,7 +9,6 @@ from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -19,17 +18,18 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from .fusion import (
FusedRMSQuantKey,
)
from .inductor_pass import enable_fake_mode
from ..activation_quant_fusion import ActivationQuantPattern
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .rms_quant_fusion import (
FusedRMSQuantKey,
)
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()

View File

@@ -20,10 +20,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)

View File

@@ -8,38 +8,39 @@ from torch import fx as fx
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
from .fusion.rocm_aiter_fusion import (
RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass,
)
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
from .sequence_parallelism import SequenceParallelismPass
from .fusion.act_quant_fusion import ActivationQuantFusionPass
from .fusion.attn_quant_fusion import AttnFusionPass
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
from .fusion.sequence_parallelism import SequenceParallelismPass
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import (
CustomGraphPass,
InductorPass,
get_pass_context,
)
from .noop_elimination import NoOpEliminationPass
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)

View File

@@ -10,8 +10,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)

View File

@@ -9,8 +9,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
from vllm.logger import init_logger
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from ..vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):