[torch.compile] Reorganize vllm/compilation and tests/compile (0/N for vLLM IR) (#33731)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: ProExpertProg <luka.govedic@gmail.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -22,11 +22,6 @@ from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import pass_context
|
||||
from vllm.compilation.partition_rules import (
|
||||
inductor_partition_rule_context,
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.config.utils import Range, hash_factors
|
||||
@@ -44,8 +39,12 @@ from .compiler_interface import (
|
||||
is_compile_cache_enabled,
|
||||
)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .pass_manager import PostGradPassManager
|
||||
from .partition_rules import (
|
||||
inductor_partition_rule_context,
|
||||
should_split,
|
||||
)
|
||||
from .passes.inductor_pass import InductorPass, pass_context
|
||||
from .passes.pass_manager import PostGradPassManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
@@ -24,12 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
flashinfer_comm: ModuleType | None = None
|
||||
if find_spec("flashinfer"):
|
||||
try:
|
||||
@@ -45,406 +46,6 @@ logger = init_logger(__name__)
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [mul, mm_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
mm = torch.ops.aten.mm.default(mul, mm_weight)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
|
||||
return torch.ops.aten.mm.default(all_gather, weight)
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
|
||||
x,
|
||||
[weight],
|
||||
gather_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm = torch.ops.aten._scaled_mm.default(
|
||||
input,
|
||||
mat2=mat2,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
scaled_mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
return torch.ops.aten._scaled_mm.default(
|
||||
all_gather,
|
||||
mat2=weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=cutlass_mm_output,
|
||||
a=input,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
cutlass_scaled_mm[1],
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
s2 = weight.shape[1]
|
||||
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=output,
|
||||
a=all_gather,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
return cutlass_scaled_mm[1]
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Enable symmetric memory for the TP process group
|
||||
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="async_tp_pass"
|
||||
)
|
||||
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
# These fusions are enabled only for bfloat16 models because
|
||||
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
|
||||
# only supports bfloat16 as the output dtype.
|
||||
if self.model_dtype == torch.bfloat16:
|
||||
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
|
||||
@@ -623,6 +224,15 @@ class FlashInferFusedAllReduceParams:
|
||||
}
|
||||
|
||||
|
||||
# TODO(luka): unify
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
@@ -22,11 +22,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .fx_utils import is_func
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..fx_utils import is_func
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
P = ParamSpec("P")
|
||||
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [mul, mm_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
mm = torch.ops.aten.mm.default(mul, mm_weight)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
|
||||
return torch.ops.aten.mm.default(all_gather, weight)
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
|
||||
x,
|
||||
[weight],
|
||||
gather_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm = torch.ops.aten._scaled_mm.default(
|
||||
input,
|
||||
mat2=mat2,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
scaled_mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
return torch.ops.aten._scaled_mm.default(
|
||||
all_gather,
|
||||
mat2=weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=cutlass_mm_output,
|
||||
a=input,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
cutlass_scaled_mm[1],
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
s2 = weight.shape[1]
|
||||
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=output,
|
||||
a=all_gather,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
return cutlass_scaled_mm[1]
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Enable symmetric memory for the TP process group
|
||||
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="async_tp_pass"
|
||||
)
|
||||
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
# These fusions are enabled only for bfloat16 models because
|
||||
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
|
||||
# only supports bfloat16 as the output dtype.
|
||||
if self.model_dtype == torch.bfloat16:
|
||||
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
@@ -15,10 +15,10 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from .fusion import empty_bf16, empty_fp32, empty_i64
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -25,13 +25,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
)
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -9,7 +9,6 @@ from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -19,17 +18,18 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..activation_quant_fusion import ActivationQuantPattern
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -20,10 +20,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -8,38 +8,39 @@ from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
from .fusion.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
from .fusion.act_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion.attn_quant_fusion import AttnFusionPass
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from .fusion.collective_fusion import AsyncTPPass
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .utility.fix_functionalization import FixFunctionalizationPass
|
||||
from .utility.noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
0
vllm/compilation/passes/utility/__init__.py
Normal file
0
vllm/compilation/passes/utility/__init__.py
Normal file
@@ -10,8 +10,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -9,8 +9,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
|
||||
Reference in New Issue
Block a user