Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,14 +5,21 @@ from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
|
||||
register_replacement)
|
||||
from torch._inductor.pattern_matcher import (
|
||||
PatternMatcherPass,
|
||||
fwd_only,
|
||||
register_replacement,
|
||||
)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
@@ -29,11 +36,11 @@ SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
FUSED_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
|
||||
}
|
||||
silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"))
|
||||
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
)
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[
|
||||
kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
|
||||
@@ -49,16 +56,18 @@ class ActivationQuantPattern(ABC):
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
assert self.quant_key in FUSED_OPS, \
|
||||
assert self.quant_key in FUSED_OPS, (
|
||||
f"unsupported fusion scheme {self.quant_key}"
|
||||
)
|
||||
self.FUSED_OP = FUSED_OPS[self.quant_key]
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
@@ -72,37 +81,40 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
|
||||
def __init__(self, symmetric: bool = True):
|
||||
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric)
|
||||
quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(quant_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor, scale: torch.Tensor):
|
||||
at1 = auto_functionalized(SILU_MUL_OP,
|
||||
result=result_silu_mul,
|
||||
input=input)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale)
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale
|
||||
)
|
||||
return at2[1]
|
||||
|
||||
def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor, scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
scale=scale)
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP, result=result, input=input, scale=scale
|
||||
)
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
self.empty_quant(5, 4), # result
|
||||
empty_bf16(5, 4), # result_silu_mul
|
||||
empty_bf16(5, 4), # input
|
||||
empty_fp32(1, 1) # scale
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
@@ -117,28 +129,37 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
super().__init__(kNvfp4Quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, output_scale: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor, input: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(SILU_MUL_OP,
|
||||
result=result_silu_mul,
|
||||
input=input)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
output=result,
|
||||
input=at1[1],
|
||||
output_scale=output_scale,
|
||||
input_scale=scale)
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=result,
|
||||
input=at1[1],
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
)
|
||||
return at2[1], at2[2]
|
||||
|
||||
def replacement(result: torch.Tensor, output_scale: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor, input: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
result_block_scale=output_scale,
|
||||
input=input,
|
||||
input_global_scale=scale)
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
result_block_scale=output_scale,
|
||||
input=input,
|
||||
input_global_scale=scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
@@ -146,7 +167,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
empty_i32(128, 4), # output_scale
|
||||
empty_bf16(5, 64), # result_silu_mul
|
||||
empty_bf16(5, 64), # input
|
||||
empty_fp32(1, 1) # scale
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
@@ -167,7 +188,8 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="activation_quant_fusion_pass")
|
||||
pass_name="activation_quant_fusion_pass"
|
||||
)
|
||||
|
||||
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
|
||||
pattern_silu_mul_fp8.register(self.patterns)
|
||||
@@ -184,6 +206,9 @@ class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, ActivationQuantPattern,
|
||||
SiluMulFp8StaticQuantPattern,
|
||||
SiluMulNvfp4QuantPattern)
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
ActivationQuantPattern,
|
||||
SiluMulFp8StaticQuantPattern,
|
||||
SiluMulNvfp4QuantPattern,
|
||||
)
|
||||
|
||||
@@ -20,8 +20,12 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
|
||||
|
||||
from .compiler_interface import (CompilerInterface, EagerAdaptor,
|
||||
InductorAdaptor, InductorStandaloneAdaptor)
|
||||
from .compiler_interface import (
|
||||
CompilerInterface,
|
||||
EagerAdaptor,
|
||||
InductorAdaptor,
|
||||
InductorStandaloneAdaptor,
|
||||
)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .pass_manager import PostGradPassManager
|
||||
@@ -33,9 +37,11 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
if compilation_config.use_inductor:
|
||||
# Use standalone compile only if requested, version is new enough,
|
||||
# and the symbol actually exists in this PyTorch build.
|
||||
if (envs.VLLM_USE_STANDALONE_COMPILE
|
||||
and is_torch_equal_or_newer("2.8.0.dev")
|
||||
and hasattr(torch._inductor, "standalone_compile")):
|
||||
if (
|
||||
envs.VLLM_USE_STANDALONE_COMPILE
|
||||
and is_torch_equal_or_newer("2.8.0.dev")
|
||||
and hasattr(torch._inductor, "standalone_compile")
|
||||
):
|
||||
logger.debug("Using InductorStandaloneAdaptor")
|
||||
return InductorStandaloneAdaptor()
|
||||
else:
|
||||
@@ -70,10 +76,9 @@ class CompilerManager:
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
"""
|
||||
Initialize the cache directory for the compiler.
|
||||
|
||||
@@ -101,9 +106,9 @@ class CompilerManager:
|
||||
# do not use eval(), it is unsafe.
|
||||
self.cache = ast.literal_eval(f.read())
|
||||
|
||||
self.compiler.initialize_cache(cache_dir=cache_dir,
|
||||
disable_cache=disable_cache,
|
||||
prefix=prefix)
|
||||
self.compiler.initialize_cache(
|
||||
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
||||
)
|
||||
|
||||
def save_to_file(self):
|
||||
if self.disable_cache or not self.is_cache_updated:
|
||||
@@ -113,35 +118,46 @@ class CompilerManager:
|
||||
with open(self.cache_file_path, "w") as f:
|
||||
f.write(data)
|
||||
|
||||
def load(self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Optional[Callable]:
|
||||
def load(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Optional[Callable]:
|
||||
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
|
||||
compiled_graph = self.compiler.load(handle, graph, example_inputs,
|
||||
graph_index, runtime_shape)
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, runtime_shape
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for dynamic shape from %s via "
|
||||
"handle %s", graph_index, self.compiler.name, handle)
|
||||
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for shape %s from %s via "
|
||||
"handle %s", graph_index, str(runtime_shape),
|
||||
self.compiler.name, handle)
|
||||
"Directly load the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
def compile(self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None) -> Any:
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
graph_index: int = 0,
|
||||
num_graphs: int = 1,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Any:
|
||||
if graph_index == 0:
|
||||
# before compiling the first graph, record the start time
|
||||
global compilation_start_time
|
||||
@@ -152,8 +168,7 @@ class CompilerManager:
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index,
|
||||
runtime_shape)
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == num_graphs - 1:
|
||||
# after loading the last graph for this shape, record the time.
|
||||
@@ -163,12 +178,16 @@ class CompilerManager:
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for dynamic shape "
|
||||
"from the cache, took %.3f s", elapsed)
|
||||
"from the cache, took %.3f s",
|
||||
elapsed,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Directly load the compiled graph(s) for shape %s "
|
||||
"from the cache, took %.3f s", str(runtime_shape),
|
||||
elapsed)
|
||||
"from the cache, took %.3f s",
|
||||
str(runtime_shape),
|
||||
elapsed,
|
||||
)
|
||||
return compiled_graph
|
||||
|
||||
# no compiler cached the graph, or the cache is disabled,
|
||||
@@ -177,37 +196,41 @@ class CompilerManager:
|
||||
# Let compile_fx generate a key for us
|
||||
maybe_key = None
|
||||
else:
|
||||
maybe_key = \
|
||||
f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph, example_inputs, additional_inductor_config, runtime_shape,
|
||||
maybe_key)
|
||||
graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
|
||||
)
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
# store the artifact in the cache
|
||||
if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
|
||||
self.cache[(runtime_shape, graph_index,
|
||||
self.compiler.name)] = handle
|
||||
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
# adds some info logging for the first graph
|
||||
if runtime_shape is None:
|
||||
logger.info(
|
||||
"Cache the graph for dynamic shape for later use")
|
||||
logger.info("Cache the graph for dynamic shape for later use")
|
||||
else:
|
||||
logger.info("Cache the graph of shape %s for later use",
|
||||
str(runtime_shape))
|
||||
logger.info(
|
||||
"Cache the graph of shape %s for later use", str(runtime_shape)
|
||||
)
|
||||
if runtime_shape is None:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for dynamic shape from %s via "
|
||||
"handle %s", graph_index, self.compiler.name, handle)
|
||||
"Store the %s-th graph for dynamic shape from %s via handle %s",
|
||||
graph_index,
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Store the %s-th graph for shape %s from %s via handle %s",
|
||||
graph_index, str(runtime_shape), self.compiler.name,
|
||||
handle)
|
||||
graph_index,
|
||||
str(runtime_shape),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
)
|
||||
|
||||
# after compiling the last graph, record the end time
|
||||
if graph_index == num_graphs - 1:
|
||||
@@ -215,11 +238,13 @@ class CompilerManager:
|
||||
elapsed = now - compilation_start_time
|
||||
compilation_config.compilation_time += elapsed
|
||||
if runtime_shape is None:
|
||||
logger.info("Compiling a graph for dynamic shape takes %.2f s",
|
||||
elapsed)
|
||||
logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
|
||||
else:
|
||||
logger.info("Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape, elapsed)
|
||||
logger.info(
|
||||
"Compiling a graph for shape %s takes %.2f s",
|
||||
runtime_shape,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
@@ -232,8 +257,9 @@ class SplitItem:
|
||||
graph: fx.GraphModule
|
||||
|
||||
|
||||
def split_graph(graph: fx.GraphModule,
|
||||
ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
def split_graph(
|
||||
graph: fx.GraphModule, ops: list[str]
|
||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
node_to_subgraph_id = {}
|
||||
@@ -241,7 +267,7 @@ def split_graph(graph: fx.GraphModule,
|
||||
for node in graph.graph.nodes:
|
||||
if node.op in ("output", "placeholder"):
|
||||
continue
|
||||
if node.op == 'call_function' and str(node.target) in ops:
|
||||
if node.op == "call_function" and str(node.target) in ops:
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
@@ -254,10 +280,8 @@ def split_graph(graph: fx.GraphModule,
|
||||
# the semantics of the graph will change when we
|
||||
# have mutations in the graph
|
||||
split_gm = torch.fx.passes.split_module.split_module(
|
||||
graph,
|
||||
None,
|
||||
lambda node: node_to_subgraph_id[node],
|
||||
keep_original_order=True)
|
||||
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
@@ -271,8 +295,7 @@ def split_graph(graph: fx.GraphModule,
|
||||
module = getattr(split_gm, name)
|
||||
|
||||
graph_id = int(name.replace("submod_", ""))
|
||||
outputs.append(
|
||||
SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
||||
outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
||||
|
||||
# sort by integer graph_id, rather than string name
|
||||
outputs.sort(key=lambda x: x.graph_id)
|
||||
@@ -295,11 +318,16 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
has some special cudagraph output handling.
|
||||
"""
|
||||
|
||||
def __init__(self, module: torch.fx.GraphModule,
|
||||
compile_submod_names: list[str], vllm_config: VllmConfig,
|
||||
vllm_backend: "VllmBackend"):
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.fx.GraphModule,
|
||||
compile_submod_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
vllm_backend: "VllmBackend",
|
||||
):
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
self.fake_mode = detect_fake_mode()
|
||||
self.compile_submod_names = compile_submod_names
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
@@ -316,9 +344,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
with self.fake_mode, enable_python_dispatcher():
|
||||
return super().run(*fake_args)
|
||||
|
||||
def call_module(self, target: torch.fx.node.Target,
|
||||
args: tuple[torch.fx.node.Argument,
|
||||
...], kwargs: dict[str, Any]) -> Any:
|
||||
def call_module(
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[torch.fx.node.Argument, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
assert isinstance(target, str)
|
||||
output = super().call_module(target, args, kwargs)
|
||||
|
||||
@@ -330,26 +361,34 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
]
|
||||
global compilation_start_time
|
||||
|
||||
compiled_graph_for_dynamic_shape = self.vllm_backend.\
|
||||
compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None)
|
||||
compiled_graph_for_dynamic_shape = (
|
||||
self.vllm_backend.compiler_manager.compile(
|
||||
submod,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=index,
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None,
|
||||
)
|
||||
)
|
||||
# Lazy import here to avoid circular import
|
||||
from .piecewise_backend import PiecewiseBackend
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
submod, self.vllm_config, index,
|
||||
len(self.compile_submod_names), sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape, self.vllm_backend)
|
||||
submod,
|
||||
self.vllm_config,
|
||||
index,
|
||||
len(self.compile_submod_names),
|
||||
sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape,
|
||||
self.vllm_backend,
|
||||
)
|
||||
|
||||
if (self.compilation_config.cudagraph_mode.\
|
||||
has_piecewise_cudagraphs() and
|
||||
not self.compilation_config.use_inductor_graph_partition):
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and not self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
# We're using Dynamo-based piecewise splitting, so we wrap
|
||||
# the whole subgraph with a static graph wrapper.
|
||||
from .cuda_graph import CUDAGraphOptions
|
||||
@@ -357,7 +396,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
||||
# class) as platform dependent.
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls())
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
# Always assign PIECEWISE runtime mode to the
|
||||
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
||||
@@ -370,7 +410,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=piecewise_backend.is_first_graph,
|
||||
gc_disable=not piecewise_backend.is_first_graph,
|
||||
weak_ref_output=piecewise_backend.is_last_graph))
|
||||
weak_ref_output=piecewise_backend.is_last_graph,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.module.__dict__[target] = piecewise_backend
|
||||
|
||||
@@ -388,8 +430,9 @@ model_tag: str = "backbone"
|
||||
def set_model_tag(tag: str):
|
||||
"""Context manager to set the model tag."""
|
||||
global model_tag
|
||||
assert tag != model_tag, \
|
||||
assert tag != model_tag, (
|
||||
f"Model tag {tag} is the same as the current tag {model_tag}."
|
||||
)
|
||||
old_tag = model_tag
|
||||
model_tag = tag
|
||||
try:
|
||||
@@ -430,7 +473,6 @@ class VllmBackend:
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
|
||||
# if the model is initialized with a non-empty prefix,
|
||||
# then usually it's enough to use that prefix,
|
||||
# e.g. language_model, vision_model, etc.
|
||||
@@ -449,7 +491,8 @@ class VllmBackend:
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.compiler_manager: CompilerManager = CompilerManager(
|
||||
self.compilation_config)
|
||||
self.compilation_config
|
||||
)
|
||||
|
||||
# `torch.compile` is JIT compiled, so we don't need to
|
||||
# do anything here
|
||||
@@ -465,8 +508,10 @@ class VllmBackend:
|
||||
if PASS_KEY in inductor_config:
|
||||
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
|
||||
# PassManager already added to config, make sure it's correct
|
||||
assert (inductor_config[PASS_KEY].uuid() ==
|
||||
self.post_grad_pass_manager.uuid())
|
||||
assert (
|
||||
inductor_config[PASS_KEY].uuid()
|
||||
== self.post_grad_pass_manager.uuid()
|
||||
)
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(inductor_config[PASS_KEY], InductorPass)
|
||||
@@ -474,7 +519,6 @@ class VllmBackend:
|
||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
|
||||
|
||||
vllm_config = self.vllm_config
|
||||
if not self.compilation_config.cache_dir:
|
||||
# no provided cache dir, generate one based on the known factors
|
||||
@@ -495,12 +539,12 @@ class VllmBackend:
|
||||
|
||||
# 2. factors come from the code files that are traced by Dynamo (
|
||||
# it mainly summarizes how the model is used in forward pass)
|
||||
forward_code_files = list(
|
||||
sorted(self.compilation_config.traced_files))
|
||||
forward_code_files = list(sorted(self.compilation_config.traced_files))
|
||||
self.compilation_config.traced_files.clear()
|
||||
logger.debug(
|
||||
"Traced files (to be considered for compilation cache):\n%s",
|
||||
"\n".join(forward_code_files))
|
||||
"\n".join(forward_code_files),
|
||||
)
|
||||
hash_content = []
|
||||
for filepath in forward_code_files:
|
||||
hash_content.append(filepath)
|
||||
@@ -511,8 +555,10 @@ class VllmBackend:
|
||||
with open(filepath) as f:
|
||||
hash_content.append(f.read())
|
||||
import hashlib
|
||||
code_hash = hashlib.md5("\n".join(hash_content).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
|
||||
code_hash = hashlib.md5(
|
||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
factors.append(code_hash)
|
||||
|
||||
# 3. compiler hash
|
||||
@@ -520,8 +566,9 @@ class VllmBackend:
|
||||
factors.append(compiler_hash)
|
||||
|
||||
# combine all factors to generate the cache dir
|
||||
hash_key = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
hash_key = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
@@ -535,8 +582,7 @@ class VllmBackend:
|
||||
self.compilation_config.cache_dir = cache_dir
|
||||
rank = vllm_config.parallel_config.rank
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
|
||||
self.prefix)
|
||||
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
|
||||
os.makedirs(local_cache_dir, exist_ok=True)
|
||||
self.compilation_config.local_cache_dir = local_cache_dir
|
||||
|
||||
@@ -545,16 +591,19 @@ class VllmBackend:
|
||||
if disable_cache:
|
||||
logger.info("vLLM's torch.compile cache is disabled.")
|
||||
else:
|
||||
logger.info("Using cache directory: %s for vLLM's torch.compile",
|
||||
local_cache_dir)
|
||||
logger.info(
|
||||
"Using cache directory: %s for vLLM's torch.compile", local_cache_dir
|
||||
)
|
||||
|
||||
self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
|
||||
self.prefix)
|
||||
self.compiler_manager.initialize_cache(
|
||||
local_cache_dir, disable_cache, self.prefix
|
||||
)
|
||||
|
||||
# when dynamo calls the backend, it means the bytecode
|
||||
# transform and analysis are done
|
||||
compilation_counter.num_graphs_seen += 1
|
||||
from .monitor import torch_compile_start_time
|
||||
|
||||
dynamo_time = time.time() - torch_compile_start_time
|
||||
logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
|
||||
self.compilation_config.compilation_time += dynamo_time
|
||||
@@ -567,7 +616,8 @@ class VllmBackend:
|
||||
self.configure_post_pass()
|
||||
|
||||
self.split_gm, self.piecewise_graphs = split_graph(
|
||||
graph, self.compilation_config.splitting_ops)
|
||||
graph, self.compilation_config.splitting_ops
|
||||
)
|
||||
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
|
||||
@@ -576,25 +626,27 @@ class VllmBackend:
|
||||
lazy_format_graph_code("before split", self.graph)
|
||||
lazy_format_graph_code("after split", self.split_gm)
|
||||
|
||||
compilation_counter.num_piecewise_graphs_seen += len(
|
||||
self.piecewise_graphs)
|
||||
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
|
||||
submod_names_to_compile = [
|
||||
item.submod_name for item in self.piecewise_graphs
|
||||
item.submod_name
|
||||
for item in self.piecewise_graphs
|
||||
if not item.is_splitting_graph
|
||||
]
|
||||
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
|
||||
self.vllm_config,
|
||||
self).run(*example_inputs)
|
||||
PiecewiseCompileInterpreter(
|
||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||
).run(*example_inputs)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
# code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
|
||||
# use `print_readable` because it can include submodules
|
||||
src = "from __future__ import annotations\nimport torch\n" + \
|
||||
self.split_gm.print_readable(print_output=False)
|
||||
src = (
|
||||
"from __future__ import annotations\nimport torch\n"
|
||||
+ self.split_gm.print_readable(print_output=False)
|
||||
)
|
||||
src = src.replace("<lambda>", "GraphModule")
|
||||
with open(graph_path, "w") as f:
|
||||
f.write(src)
|
||||
@@ -603,12 +655,15 @@ class VllmBackend:
|
||||
|
||||
self._called = True
|
||||
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
|
||||
not self.compilation_config.cudagraph_copy_inputs:
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
or not self.compilation_config.cudagraph_copy_inputs
|
||||
):
|
||||
return self.split_gm
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
fake_args = [
|
||||
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
@@ -619,10 +674,12 @@ class VllmBackend:
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
||||
|
||||
self.sym_tensor_indices = [
|
||||
i for i, x in enumerate(fake_args)
|
||||
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
|
||||
any(is_symbolic(d) for d in x.size())
|
||||
i
|
||||
for i, x in enumerate(fake_args)
|
||||
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
|
||||
and any(is_symbolic(d) for d in x.size())
|
||||
]
|
||||
|
||||
# compiler managed cudagraph input buffers
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,14 +24,14 @@ class CompilerInterface:
|
||||
"""
|
||||
The interface for a compiler that can be used by vLLM.
|
||||
"""
|
||||
|
||||
# The name of the compiler, e.g. inductor.
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
@@ -93,12 +93,14 @@ class CompilerInterface:
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
@@ -150,11 +152,13 @@ def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
@@ -169,18 +173,19 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
|
||||
"""
|
||||
|
||||
name = "inductor_standalone"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
@@ -203,12 +208,14 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
dynamic_shapes = "from_tracing_context"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
with pass_context(runtime_shape):
|
||||
compiled_graph = standalone_compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config})
|
||||
options={"config_patches": current_config},
|
||||
)
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
@@ -218,19 +225,23 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
compilation_counter.num_compiled_artifacts_saved += 1
|
||||
return compiled_graph, (key, path)
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
path = handle[1]
|
||||
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
||||
path=path, format="unpacked")
|
||||
path=path, format="unpacked"
|
||||
)
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
def compiled_graph_wrapper(*args):
|
||||
@@ -250,21 +261,22 @@ class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
|
||||
"""
|
||||
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
|
||||
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a sub-directory
|
||||
@@ -288,6 +300,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
@@ -308,8 +321,8 @@ class InductorAdaptor(CompilerInterface):
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import (FxGraphCache,
|
||||
compiled_fx_graph_hash)
|
||||
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
@@ -326,7 +339,8 @@ class InductorAdaptor(CompilerInterface):
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
if cell.cell_contents.__code__.co_filename.startswith(
|
||||
self.base_cache_dir):
|
||||
self.base_cache_dir
|
||||
):
|
||||
# this is the real file path compiled from Inductor
|
||||
file_path = cell.cell_contents.__code__.co_filename
|
||||
break
|
||||
@@ -338,8 +352,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(
|
||||
*args, **kwargs)
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
@@ -353,8 +366,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(
|
||||
self.base_cache_dir):
|
||||
if code.co_filename.startswith(self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
@@ -387,29 +399,38 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash))
|
||||
patch(
|
||||
"torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash,
|
||||
)
|
||||
)
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env))
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
AOTAutogradCache)
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
_get_shape_env))
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache))
|
||||
_check_can_cache,
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
stack.enter_context(self.metrics_context())
|
||||
@@ -422,23 +443,26 @@ class InductorAdaptor(CompilerInterface):
|
||||
# standalone_compile sometime.
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False))
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
||||
)
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False))
|
||||
torch._functorch.config.patch(enable_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(
|
||||
enable_remote_autograd_cache=False))
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
|
||||
with pass_context(runtime_shape):
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config)
|
||||
config_patches=current_config,
|
||||
)
|
||||
|
||||
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
|
||||
# compilation cache. So turn off the checks if we disable the
|
||||
@@ -451,52 +475,63 @@ class InductorAdaptor(CompilerInterface):
|
||||
"failed, leading to a corrupted compilation artifact. "
|
||||
"We recommend trying to "
|
||||
"remove ~/.cache/vllm/torch_compile_cache and try again "
|
||||
"to see the real issue. ")
|
||||
"to see the real issue. "
|
||||
)
|
||||
assert file_path is not None, (
|
||||
"failed to get the file path of the compiled graph")
|
||||
"failed to get the file path of the compiled graph"
|
||||
)
|
||||
return compiled_graph, (hash_str, file_path)
|
||||
|
||||
def load(self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None) -> Callable:
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
AOTAutogradCache)
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
exit_stack.enter_context(
|
||||
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv()))
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False)
|
||||
hash_str, example_inputs, True, False
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
elif torch.__version__ >= "2.6":
|
||||
from torch._inductor.output_code import (
|
||||
CompiledFxGraphConstantsWithGm)
|
||||
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
||||
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants)
|
||||
hash_str, example_inputs, True, None, constants
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove"
|
||||
f"the cache directory and try again." # noqa
|
||||
@@ -509,6 +544,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
@@ -542,6 +578,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
@@ -553,7 +590,8 @@ def set_inductor_config(config, runtime_shape):
|
||||
# can be beneficial
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
|
||||
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
|
||||
)
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
|
||||
@@ -41,7 +41,8 @@ class CompilationCounter:
|
||||
assert getattr(self, k) - getattr(old, k) == v, (
|
||||
f"{k} not as expected, before it is {getattr(old, k)}"
|
||||
f", after it is {getattr(self, k)}, "
|
||||
f"expected diff is {v}")
|
||||
f"expected diff is {v}"
|
||||
)
|
||||
|
||||
|
||||
compilation_counter = CompilationCounter()
|
||||
|
||||
@@ -12,8 +12,7 @@ import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -46,10 +45,10 @@ class CUDAGraphWrapper:
|
||||
|
||||
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
batch_descriptor(key) from the forward context and blindly trust them
|
||||
for cudagraph dispatching.
|
||||
for cudagraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
@@ -58,18 +57,20 @@ class CUDAGraphWrapper:
|
||||
|
||||
Note: CUDAGraphWrapper does not store persistent buffers or copy any
|
||||
runtime inputs into that buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None,
|
||||
):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
@@ -91,15 +92,16 @@ class CUDAGraphWrapper:
|
||||
self.cudagraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
|
||||
= {}
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}")
|
||||
raise AttributeError(
|
||||
f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}"
|
||||
)
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
@@ -110,8 +112,10 @@ class CUDAGraphWrapper:
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
cudagraph_runtime_mode != self.runtime_mode:
|
||||
if (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||
or cudagraph_runtime_mode != self.runtime_mode
|
||||
):
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without cudagraphs.
|
||||
# We do not trigger capture/replay if the runtime mode is not
|
||||
@@ -122,8 +126,9 @@ class CUDAGraphWrapper:
|
||||
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = \
|
||||
CUDAGraphEntry(batch_descriptor=batch_descriptor)
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
|
||||
batch_descriptor=batch_descriptor
|
||||
)
|
||||
|
||||
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||
|
||||
@@ -133,8 +138,11 @@ class CUDAGraphWrapper:
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a cudagraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
logger.debug(
|
||||
"Capturing a cudagraph on (%s,%s)",
|
||||
self.runtime_mode.name,
|
||||
entry.batch_descriptor,
|
||||
)
|
||||
# validate that cudagraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
@@ -153,8 +161,7 @@ class CUDAGraphWrapper:
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
@@ -193,7 +200,8 @@ class CUDAGraphWrapper:
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for cudagraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
f"got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
|
||||
@@ -34,11 +34,11 @@ def ignore_torch_compile(cls: _T) -> _T:
|
||||
a support_torch_compile decorator, but we don't want to
|
||||
compile the class `cls` that inherits the parent class.
|
||||
This only ignores compiling the forward of the class the
|
||||
decorator is applied to.
|
||||
decorator is applied to.
|
||||
|
||||
If the parent has ignore_torch_compile but the child has
|
||||
support_torch_compile, the child will still be compiled.
|
||||
|
||||
|
||||
If the class has one or more submodules
|
||||
that have support_torch_compile decorator applied, compile will
|
||||
not be ignored for those submodules.
|
||||
@@ -58,21 +58,18 @@ def _should_ignore_torch_compile(cls) -> bool:
|
||||
def support_torch_compile(
|
||||
*,
|
||||
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
|
||||
) -> Callable[[_T], _T]:
|
||||
...
|
||||
) -> Callable[[_T], _T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
|
||||
) -> Callable[[_T], _T]:
|
||||
...
|
||||
) -> Callable[[_T], _T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(cls: _T) -> _T:
|
||||
...
|
||||
def support_torch_compile(cls: _T) -> _T: ...
|
||||
|
||||
|
||||
def support_torch_compile(
|
||||
@@ -89,8 +86,7 @@ def support_torch_compile(
|
||||
```python
|
||||
@support_torch_compile
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
|
||||
...
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
|
||||
```
|
||||
|
||||
Usage 2: use as a decorator with arguments:
|
||||
@@ -98,8 +94,7 @@ def support_torch_compile(
|
||||
```python
|
||||
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
|
||||
...
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
|
||||
```
|
||||
|
||||
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
|
||||
@@ -139,7 +134,7 @@ def support_torch_compile(
|
||||
def cls_decorator_helper(cls: _T) -> _T:
|
||||
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
|
||||
# to avoid too much indentation for `_support_torch_compile``
|
||||
if not hasattr(cls, 'forward'):
|
||||
if not hasattr(cls, "forward"):
|
||||
raise TypeError("decorated class should have a forward method.")
|
||||
sig = inspect.signature(cls.forward)
|
||||
inferred_dynamic_arg_dims = dynamic_arg_dims
|
||||
@@ -147,26 +142,31 @@ def support_torch_compile(
|
||||
inferred_dynamic_arg_dims = {}
|
||||
for k, v in sig.parameters.items():
|
||||
if v.annotation in [
|
||||
torch.Tensor, Optional[torch.Tensor],
|
||||
IntermediateTensors, Optional[IntermediateTensors]
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
IntermediateTensors,
|
||||
Optional[IntermediateTensors],
|
||||
]:
|
||||
inferred_dynamic_arg_dims[k] = 0
|
||||
|
||||
logger.debug(("Inferred dynamic dimensions for "
|
||||
"forward method of %s: %s"), cls,
|
||||
list(inferred_dynamic_arg_dims.keys()))
|
||||
logger.debug(
|
||||
("Inferred dynamic dimensions for forward method of %s: %s"),
|
||||
cls,
|
||||
list(inferred_dynamic_arg_dims.keys()),
|
||||
)
|
||||
|
||||
if len(inferred_dynamic_arg_dims) == 0:
|
||||
raise ValueError(
|
||||
"No dynamic dimensions found in the forward method of "
|
||||
f"{cls}. Please provide dynamic_arg_dims explicitly.")
|
||||
f"{cls}. Please provide dynamic_arg_dims explicitly."
|
||||
)
|
||||
|
||||
for k in inferred_dynamic_arg_dims:
|
||||
if k not in sig.parameters:
|
||||
raise ValueError(
|
||||
f"Argument {k} not found in the forward method of {cls}")
|
||||
return _support_torch_compile(cls, inferred_dynamic_arg_dims,
|
||||
enable_if)
|
||||
f"Argument {k} not found in the forward method of {cls}"
|
||||
)
|
||||
return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
|
||||
|
||||
if cls is not None:
|
||||
# use `support_torch_compile` as a decorator without arguments
|
||||
@@ -191,29 +191,32 @@ def _support_torch_compile(
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWrapperWithCustomDispatcher
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)
|
||||
|
||||
old_init = cls.__init__
|
||||
|
||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
|
||||
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
self.vllm_config = vllm_config
|
||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = \
|
||||
vllm_config.compilation_config.level in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
|
||||
] or not supports_dynamo() or _should_ignore_torch_compile(
|
||||
self.__class__) or not enable_compile
|
||||
self.do_not_compile = (
|
||||
vllm_config.compilation_config.level
|
||||
in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS]
|
||||
or not supports_dynamo()
|
||||
or _should_ignore_torch_compile(self.__class__)
|
||||
or not enable_compile
|
||||
)
|
||||
if self.do_not_compile:
|
||||
return
|
||||
|
||||
compilation_counter.num_models_seen += 1
|
||||
TorchCompileWrapperWithCustomDispatcher.__init__(
|
||||
self, compilation_level=vllm_config.compilation_config.level)
|
||||
self, compilation_level=vllm_config.compilation_config.level
|
||||
)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
@@ -235,26 +238,23 @@ def _support_torch_compile(
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [
|
||||
arg.ndim + dim if dim < 0 else dim for dim in dims
|
||||
]
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [
|
||||
tensor.ndim + dim if dim < 0 else dim
|
||||
for dim in dims
|
||||
tensor.ndim + dim if dim < 0 else dim for dim in dims
|
||||
]
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}.")
|
||||
f" {dims} for argument {k} with type {type(arg)}."
|
||||
)
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
logger.debug("Start compiling function %s",
|
||||
self.original_code_object)
|
||||
logger.debug("Start compiling function %s", self.original_code_object)
|
||||
|
||||
# if we don't use custom dispatcher, we can directly call the
|
||||
# compiled function and let torch.compile handle the dispatching,
|
||||
@@ -263,8 +263,7 @@ def _support_torch_compile(
|
||||
# it seems Dynamo reuse the compilation across instances,
|
||||
# while we need to make sure the compiled code is not reused.
|
||||
# we need to control all the compilation of the model.
|
||||
torch._dynamo.eval_frame.remove_from_cache(
|
||||
self.original_code_object)
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
@@ -272,7 +271,8 @@ def _support_torch_compile(
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
self.original_code_object.co_filename)
|
||||
self.original_code_object.co_filename
|
||||
)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call
|
||||
@@ -282,8 +282,7 @@ def _support_torch_compile(
|
||||
|
||||
def patched_inline_call(parent, func, args, kwargs):
|
||||
code = func.get_code()
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
code.co_filename)
|
||||
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(parent, func, args, kwargs)
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
@@ -293,20 +292,20 @@ def _support_torch_compile(
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches[
|
||||
"enable_cpp_symbolic_shape_guards"] = False
|
||||
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug(
|
||||
"enable_cpp_symbolic_shape_guards config not available")
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
with patch.object(
|
||||
InliningInstructionTranslator, "inline_call",
|
||||
patched_inline_call), torch._dynamo.config.patch(
|
||||
**dynamo_config_patches
|
||||
), maybe_use_cudagraph_partition_wrapper(
|
||||
self.vllm_config), _torch27_patch_tensor_subclasses():
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
):
|
||||
output = self.compiled_callable(*args, **kwargs)
|
||||
return output
|
||||
|
||||
@@ -336,18 +335,20 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
if (
|
||||
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
from torch._inductor.utils import CUDAGraphWrapperMetadata
|
||||
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls())
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
def customized_cudagraph_wrapper(f,
|
||||
metadata: CUDAGraphWrapperMetadata):
|
||||
def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
|
||||
partition_id = metadata.partition_index
|
||||
num_partitions = metadata.num_partitions
|
||||
return static_graph_wrapper_class(
|
||||
@@ -358,15 +359,19 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
debug_log_enable=partition_id == 0,
|
||||
gc_disable=partition_id != 0,
|
||||
weak_ref_output=partition_id == num_partitions - 1,
|
||||
))
|
||||
),
|
||||
)
|
||||
|
||||
torch._inductor.utils.set_customized_partition_wrappers(
|
||||
customized_cudagraph_wrapper)
|
||||
customized_cudagraph_wrapper
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
if (
|
||||
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@@ -378,23 +383,32 @@ def _torch27_patch_tensor_subclasses():
|
||||
`BasevLLMParameters` without having to replace them with regular tensors
|
||||
before `torch.compile`-time.
|
||||
"""
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter,
|
||||
)
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(
|
||||
torch.__version__) < version.parse("2.8"):
|
||||
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
|
||||
yield
|
||||
return
|
||||
|
||||
with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
|
||||
BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
|
||||
RowvLLMParameter
|
||||
]),
|
||||
patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
|
||||
return_false)):
|
||||
with (
|
||||
torch._dynamo.config.patch(
|
||||
"traceable_tensor_subclasses",
|
||||
[
|
||||
BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
_ColumnvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
@@ -31,8 +31,9 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
logger.debug("XPU platform does not support fix functionalization"
|
||||
"pass currently.")
|
||||
logger.debug(
|
||||
"XPU platform does not support fix functionalizationpass currently."
|
||||
)
|
||||
return
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
@@ -45,19 +46,21 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
query = kwargs['query']
|
||||
key = kwargs['key']
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
if (is_func(query, operator.getitem)
|
||||
and is_func(key, operator.getitem)
|
||||
and query.args[0] == key.args[0]
|
||||
and is_func(query.args[0],
|
||||
torch.ops.aten.split_with_sizes.default)
|
||||
and all(
|
||||
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||
for getitem_node in getitem_nodes.values()
|
||||
for user in getitem_node.users)):
|
||||
if (
|
||||
is_func(query, operator.getitem)
|
||||
and is_func(key, operator.getitem)
|
||||
and query.args[0] == key.args[0]
|
||||
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
|
||||
and all(
|
||||
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||
for getitem_node in getitem_nodes.values()
|
||||
for user in getitem_node.users
|
||||
)
|
||||
):
|
||||
# Pattern where query and key are slices of an mm_node.
|
||||
# While functionalized, results at [1] and [2] are scattered
|
||||
# back into mm_node. So after de-functionalization, we can
|
||||
@@ -66,8 +69,9 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
mm_node = query.args[0].args[0]
|
||||
for user in getitem_nodes.values():
|
||||
for user_of_getitem in user.users:
|
||||
if is_func(user_of_getitem,
|
||||
torch.ops.aten.slice_scatter.default):
|
||||
if is_func(
|
||||
user_of_getitem, torch.ops.aten.slice_scatter.default
|
||||
):
|
||||
user_of_getitem.replace_all_uses_with(mm_node)
|
||||
self._remove(user_of_getitem)
|
||||
self._remove(user)
|
||||
@@ -81,49 +85,54 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
# do this blindly, but in practice in vLLM it's ok. The best
|
||||
# solution is to use auto_functionalization_v2 and then use
|
||||
# inductor's builtin defunctionalization (reinplacing) pass.
|
||||
mutated_args = {1: 'query', 2: 'key'}
|
||||
mutated_args = {1: "query", 2: "key"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
|
||||
# rms_norm replacements avoid the most copies for LLaMa.
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||
mutated_args = {1: 'input', 2: 'residual'}
|
||||
mutated_args = {1: "input", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
|
||||
mutated_args = {1: 'result', 2: 'residual'}
|
||||
mutated_args = {1: "result", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
|
||||
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
|
||||
mutated_args = {1: "result", 2: "scale", 3: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: 'result'}
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input'))
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input")
|
||||
)
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input', 'scale'))
|
||||
elif hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
|
||||
mutated_args = {1: 'result', 2: 'result_block_scale'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'result_block_scale',
|
||||
'input', 'input_global_scale'))
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input", "scale")
|
||||
)
|
||||
elif (
|
||||
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
|
||||
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
|
||||
):
|
||||
mutated_args = {1: "result", 2: "result_block_scale"}
|
||||
self.defunctionalize(
|
||||
graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=(
|
||||
"result",
|
||||
"result_block_scale",
|
||||
"input",
|
||||
"input_global_scale",
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
@@ -136,12 +145,12 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
|
||||
count_removed)
|
||||
logger.debug(
|
||||
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
||||
)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: Union[torch.fx.Node,
|
||||
Iterable[torch.fx.Node]]):
|
||||
def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]):
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
@@ -150,12 +159,13 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
else:
|
||||
self.nodes_to_remove.extend(node_or_nodes)
|
||||
|
||||
def defunctionalize(self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
||||
args: Optional[tuple[Union[torch.fx.Node, str],
|
||||
...]] = None):
|
||||
def defunctionalize(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, Union[torch.fx.Node, str]],
|
||||
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
||||
):
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
@@ -165,10 +175,9 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
self.insert_defunctionalized(graph, node, args=args)
|
||||
self._remove(node)
|
||||
|
||||
def replace_users_with_mutated_args(self, node: torch.fx.Node,
|
||||
mutated_args: dict[int,
|
||||
Union[torch.fx.Node,
|
||||
str]]):
|
||||
def replace_users_with_mutated_args(
|
||||
self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]]
|
||||
):
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
@@ -194,11 +203,12 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
users[idx] = user
|
||||
return users
|
||||
|
||||
def insert_defunctionalized(self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: Optional[tuple[Union[torch.fx.Node, str],
|
||||
...]] = None):
|
||||
def insert_defunctionalized(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None,
|
||||
):
|
||||
"""
|
||||
Insert a new defunctionalized node into the graph before node.
|
||||
If one of the kwargs is 'out', provide args directly,
|
||||
@@ -210,8 +220,9 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
:param args: If we cannot use kwargs, specify args directly.
|
||||
If an arg is a string, `node.kwargs[arg]` is used.
|
||||
""" # noqa: E501
|
||||
assert is_func(node, auto_functionalized), \
|
||||
assert is_func(node, auto_functionalized), (
|
||||
f"node must be auto-functionalized, is {node} instead"
|
||||
)
|
||||
|
||||
# Create a new call to the original function
|
||||
with graph.inserting_before(node):
|
||||
@@ -220,6 +231,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
graph.call_function(function, kwargs=node.kwargs)
|
||||
else:
|
||||
# Args passed as strings refer to items in node.kwargs
|
||||
args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
for arg in args)
|
||||
args = tuple(
|
||||
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
||||
)
|
||||
graph.call_function(function, args=args)
|
||||
|
||||
@@ -12,8 +12,15 @@ from torch._ops import OpOverload
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
@@ -40,12 +47,9 @@ RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym:
|
||||
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym:
|
||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym:
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
|
||||
@@ -57,80 +61,93 @@ class FusedRMSQuantKey(NamedTuple):
|
||||
quant: type of quantization
|
||||
fused_add: does the op also perform the residual add
|
||||
"""
|
||||
|
||||
quant: QuantKey
|
||||
fused_add: bool
|
||||
|
||||
def __str__(self):
|
||||
return (f"FusedQuantKey({self.quant}, with"
|
||||
f"{'' if self.fused_add else 'out'} residual)")
|
||||
return (
|
||||
f"FusedQuantKey({self.quant}, with"
|
||||
f"{'' if self.fused_add else 'out'} residual)"
|
||||
)
|
||||
|
||||
|
||||
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
FusedRMSQuantKey(kFp8StaticTensorSym, False):
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8StaticTensorSym, True):
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8StaticTensorSym, False
|
||||
): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8StaticTensorSym, True
|
||||
): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, False
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, True
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
|
||||
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
assert key.quant in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {key.quant}"
|
||||
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
|
||||
self.QUANT_OP = QUANT_OPS[key.quant]
|
||||
|
||||
assert key in FUSED_OPS, \
|
||||
f"unsupported fused rmsnorm+quant op for {key}"
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
|
||||
),
|
||||
)
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale)
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale
|
||||
)
|
||||
|
||||
# result
|
||||
return at2[1]
|
||||
|
||||
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon)
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
# result
|
||||
return at[1]
|
||||
@@ -140,53 +157,60 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
|
||||
pm_pass)
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
|
||||
),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at1 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at[1],
|
||||
scale=scale)
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale
|
||||
)
|
||||
|
||||
# result, residual
|
||||
return at1[1], at[2]
|
||||
|
||||
def replacement(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon)
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
# result, residual
|
||||
return at[1], at[2]
|
||||
@@ -196,7 +220,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@@ -209,49 +233,59 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale,
|
||||
scale_ub=None)
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
|
||||
)
|
||||
|
||||
# result, scale
|
||||
return at2[1], at2[2]
|
||||
|
||||
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=None)
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=None,
|
||||
)
|
||||
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
@@ -261,7 +295,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@@ -274,49 +308,59 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True,
|
||||
):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon)
|
||||
at1 = auto_functionalized(self.QUANT_OP,
|
||||
result=result,
|
||||
input=at[1],
|
||||
scale=scale,
|
||||
scale_ub=None)
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at1[1], at[2], at1[2]
|
||||
|
||||
def replacement(result: torch.Tensor, input: torch.Tensor,
|
||||
residual: torch.Tensor, weight: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at = auto_functionalized(self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual)
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
@@ -326,7 +370,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1) # scale
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@@ -349,24 +393,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rmsnorm_quant_fusion_pass")
|
||||
pass_name="rmsnorm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns)
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon,
|
||||
FP8_DTYPE).register(self.patterns)
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns)
|
||||
self.patterns
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@@ -376,8 +421,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self.hash_source(self, RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern)
|
||||
return self.hash_source(
|
||||
self,
|
||||
RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,10 @@ from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kNvfp4Quant, kStaticTensorScale)
|
||||
QuantKey,
|
||||
kNvfp4Quant,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
@@ -49,21 +52,21 @@ class AttentionQuantPattern(ABC):
|
||||
self.quant_dtype = quant_key.dtype
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs}
|
||||
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
|
||||
@@ -72,6 +75,7 @@ class AttentionQuantPattern(ABC):
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
@@ -100,70 +104,85 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric)
|
||||
quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None)
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=output_quant,
|
||||
input=attn_out_view,
|
||||
scale=scale)
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
|
||||
)
|
||||
return at2[1]
|
||||
|
||||
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device)
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None)
|
||||
device=q.device,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None,
|
||||
)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # q
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # k
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # v
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # attn_output
|
||||
self.empty_quant(5,
|
||||
self.num_heads * self.head_size), # quant_output
|
||||
empty_fp32(1, 1) # scale
|
||||
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q
|
||||
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k
|
||||
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v
|
||||
self.empty(
|
||||
5, self.num_heads, self.head_size, dtype=self.dtype
|
||||
), # attn_output
|
||||
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||
pm_pass)
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
@@ -180,50 +199,67 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
super().__init__(layer, kNvfp4Quant, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None)
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale)
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device)
|
||||
device=q.device,
|
||||
)
|
||||
# attention output block scale
|
||||
output_scale_view = torch.ops.aten.view.dtype(
|
||||
output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view)
|
||||
output = RESHAPE_OP(at2[1],
|
||||
[-1, self.num_heads * self.head_size // 2])
|
||||
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view,
|
||||
)
|
||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
|
||||
inputs = [
|
||||
@@ -231,18 +267,22 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||
self.empty_quant(5, self.num_heads * self.head_size //
|
||||
2), # output_quant
|
||||
empty_i32(128, round_up(self.num_heads * self.head_size // 16,
|
||||
4)), # output_scale
|
||||
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
|
||||
empty_i32(
|
||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||
), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||
pm_pass)
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
|
||||
@@ -267,20 +307,22 @@ class AttnFusionPass(VllmPatternMatcherPass):
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(
|
||||
layer, config.model_config.dtype)
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C,
|
||||
"scaled_fp4_quant"):
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(
|
||||
layer, config.model_config.dtype)
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but no attention layers "
|
||||
"were found in CompilationConfig.static_forward_context "
|
||||
"so no fusion patterns were registered.")
|
||||
"so no fusion patterns were registered."
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@@ -290,6 +332,9 @@ class AttnFusionPass(VllmPatternMatcherPass):
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern)
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern,
|
||||
)
|
||||
|
||||
@@ -19,8 +19,9 @@ def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
|
||||
|
||||
# Returns the first specified node with the given op (if it exists)
|
||||
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
|
||||
op: OpOverload) -> Optional[fx.Node]:
|
||||
def find_specified_fn_maybe(
|
||||
nodes: Iterable[fx.Node], op: OpOverload
|
||||
) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if node.target == op:
|
||||
return node
|
||||
@@ -35,8 +36,7 @@ def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
|
||||
op: OpOverload) -> Optional[fx.Node]:
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]:
|
||||
for node in nodes:
|
||||
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
||||
return node
|
||||
|
||||
@@ -11,8 +11,7 @@ from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||
unset_fake_temporarily)
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
@@ -21,13 +20,13 @@ if is_torch_equal_or_newer("2.6"):
|
||||
else:
|
||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||
from .torch25_custom_graph_pass import ( # noqa: E501
|
||||
Torch25CustomGraphPass as CustomGraphPass)
|
||||
Torch25CustomGraphPass as CustomGraphPass,
|
||||
)
|
||||
|
||||
_pass_context = None
|
||||
|
||||
|
||||
class PassContext:
|
||||
|
||||
def __init__(self, runtime_shape: Optional[int]):
|
||||
self.runtime_shape = runtime_shape
|
||||
|
||||
@@ -106,9 +105,9 @@ class CallableInductorPass(InductorPass):
|
||||
implementation of the UUID.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
callable: Callable[[fx.Graph], None],
|
||||
uuid: Optional[Any] = None):
|
||||
def __init__(
|
||||
self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None
|
||||
):
|
||||
self.callable = callable
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
@@ -127,8 +126,7 @@ def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args, **kwargs) -> Any:
|
||||
with torch._guards.tracing(
|
||||
None), unset_fake_temporarily(), FakeTensorMode():
|
||||
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
@@ -20,6 +20,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
path = vllm_config.compile_debug_dump_path()
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE and path:
|
||||
import depyf
|
||||
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
global context_manager
|
||||
context_manager = depyf.prepare_debug(path.as_posix())
|
||||
@@ -29,8 +30,9 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
logger.info("torch.compile takes %.2f s in total",
|
||||
compilation_config.compilation_time)
|
||||
logger.info(
|
||||
"torch.compile takes %.2f s in total", compilation_config.compilation_time
|
||||
)
|
||||
global context_manager
|
||||
if context_manager is not None:
|
||||
context_manager.__exit__(None, None, None)
|
||||
@@ -46,8 +48,10 @@ def validate_cudagraph_capturing_enabled():
|
||||
# if an illegal cudagraph capturing happens, raise an error.
|
||||
global cudagraph_capturing_enabled
|
||||
if not cudagraph_capturing_enabled:
|
||||
raise RuntimeError("CUDA graph capturing detected at an inappropriate "
|
||||
"time. This operation is currently disabled.")
|
||||
raise RuntimeError(
|
||||
"CUDA graph capturing detected at an inappropriate "
|
||||
"time. This operation is currently disabled."
|
||||
)
|
||||
|
||||
|
||||
def set_cudagraph_capturing_enabled(enabled: bool):
|
||||
|
||||
@@ -122,8 +122,9 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
logger.debug("Removed %s no-op reshapes and slices", count)
|
||||
|
||||
# ---------------------- Reshape helpers ----------------------
|
||||
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
|
||||
i_dim: Union[int, SymInt]) -> bool:
|
||||
def reshape_dims_equivalent(
|
||||
self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]
|
||||
) -> bool:
|
||||
"""
|
||||
This function checks if two dimensions are equivalent.
|
||||
:param dim: The dimension arg to reshape/slice
|
||||
@@ -153,6 +154,4 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
dims: Iterable[Union[int, torch.fx.Node]],
|
||||
i_dims: Iterable[Union[int, SymInt]],
|
||||
) -> bool:
|
||||
return all(
|
||||
self.reshape_dims_equivalent(s, i_s)
|
||||
for s, i_s in zip(dims, i_dims))
|
||||
return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
|
||||
|
||||
@@ -23,15 +23,19 @@ class ConcreteSizeEntry:
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
piecewise_compile_index: int, total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend):
|
||||
def __init__(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
vllm_config: VllmConfig,
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend,
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation of static shapes and
|
||||
It mainly handles the compilation of static shapes and
|
||||
dispatching based on runtime shape.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
@@ -46,13 +50,11 @@ class PiecewiseBackend:
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = (
|
||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
|
||||
self.compile_sizes: set[int] = set(
|
||||
self.compilation_config.compile_sizes)
|
||||
self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
@@ -108,7 +110,8 @@ class PiecewiseBackend:
|
||||
self.compilation_config,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape)
|
||||
runtime_shape=runtime_shape,
|
||||
)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
|
||||
@@ -16,5 +16,6 @@ class PostCleanupPass(VllmInductorPass):
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
|
||||
stable_topological_sort(graph)
|
||||
graph.eliminate_dead_code()
|
||||
|
||||
@@ -9,8 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -23,12 +22,14 @@ logger = init_logger(__name__)
|
||||
class _RMSNormAndQuantOpHelper:
|
||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
@@ -40,60 +41,78 @@ class _RMSNormAndQuantOpHelper:
|
||||
result=result_buffer,
|
||||
input=input_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
|
||||
weight_tensor):
|
||||
def _functional_fused_add_rmsnorm(
|
||||
self, input_tensor, residual_tensor, weight_tensor
|
||||
):
|
||||
return torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
input=input_tensor,
|
||||
residual=residual_tensor,
|
||||
weight=weight_tensor,
|
||||
epsilon=self.epsilon)
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
|
||||
quant_result_buffer, input_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
def _functional_rmsnorm_then_quant(
|
||||
self,
|
||||
rmsnorm_result_buffer,
|
||||
quant_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor,
|
||||
scale_tensor,
|
||||
):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
|
||||
input_tensor,
|
||||
weight_tensor)
|
||||
rmsnorm_out_tuple = self._functional_rmsnorm(
|
||||
rmsnorm_result_buffer, input_tensor, weight_tensor
|
||||
)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
scale=scale_tensor,
|
||||
)
|
||||
return quant_out_tuple
|
||||
|
||||
def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
|
||||
input_tensor, residual_tensor,
|
||||
weight_tensor, scale_tensor):
|
||||
def _functional_fused_add_rmsnorm_then_quant(
|
||||
self,
|
||||
quant_result_buffer,
|
||||
input_tensor,
|
||||
residual_tensor,
|
||||
weight_tensor,
|
||||
scale_tensor,
|
||||
):
|
||||
if self.quant_op is None:
|
||||
raise RuntimeError(
|
||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
||||
)
|
||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
||||
input_tensor, residual_tensor, weight_tensor)
|
||||
input_tensor, residual_tensor, weight_tensor
|
||||
)
|
||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
||||
self.quant_op,
|
||||
result=quant_result_buffer,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale_tensor)
|
||||
scale=scale_tensor,
|
||||
)
|
||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
quant_op: Optional[torch._ops.OpOverload] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -103,21 +122,16 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
||||
|
||||
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.reduce_scatter.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp_group.unique_name)
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp_group.unique_name)
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
@@ -126,7 +140,6 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
return [input, permute, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
permute: torch.Tensor,
|
||||
@@ -145,26 +158,23 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter)
|
||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter,
|
||||
arg3_1)
|
||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
|
||||
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
@@ -173,7 +183,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
@@ -181,7 +190,8 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights)
|
||||
all_reduce, residual, rms_norm_weights
|
||||
)
|
||||
return rmsnorm[1], rmsnorm[2]
|
||||
|
||||
def replacement(
|
||||
@@ -191,23 +201,22 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights)
|
||||
reduce_scatter, residual, rms_norm_weights
|
||||
)
|
||||
all_gather = self._all_gather(rmsnorm[1])
|
||||
return all_gather, rmsnorm[2]
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
@@ -216,7 +225,6 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
@@ -224,7 +232,8 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
all_reduce, residual, rms_norm_weights)
|
||||
all_reduce, residual, rms_norm_weights
|
||||
)
|
||||
return rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
@@ -234,37 +243,34 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
||||
reduce_scatter, residual, rms_norm_weights)
|
||||
reduce_scatter, residual, rms_norm_weights
|
||||
)
|
||||
normalized = self._all_gather(rmsnorm[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
op: torch._ops.OpOverload):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
quant_result = torch.empty([1, 8, 4],
|
||||
device=self.device,
|
||||
dtype=FP8_DTYPE)
|
||||
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
@@ -274,7 +280,8 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
):
|
||||
all_reduce = self._all_reduce(input)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, all_reduce, weight, scale)
|
||||
rmsnorm_result, quant_result, all_reduce, weight, scale
|
||||
)
|
||||
return static_fp8[1], all_reduce
|
||||
|
||||
def replacement(
|
||||
@@ -286,34 +293,36 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
):
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm_result = torch.empty_like(reduce_scatter,
|
||||
dtype=rmsnorm_result.dtype)
|
||||
rmsnorm_result = torch.empty_like(
|
||||
reduce_scatter, dtype=rmsnorm_result.dtype
|
||||
)
|
||||
quant_result = torch.empty_like(
|
||||
rmsnorm_result, # Output of RMSNorm
|
||||
dtype=quant_result.dtype)
|
||||
dtype=quant_result.dtype,
|
||||
)
|
||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
||||
rmsnorm_result, quant_result, reduce_scatter, weight, scale)
|
||||
rmsnorm_result, quant_result, reduce_scatter, weight, scale
|
||||
)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
op: torch._ops.OpOverload):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
@@ -326,7 +335,6 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
@@ -335,8 +343,11 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
result, all_reduce, residual, rms_norm_weights, scale)
|
||||
static_fp8, rmsnorm_residual_out = (
|
||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
result, all_reduce, residual, rms_norm_weights, scale
|
||||
)
|
||||
)
|
||||
return static_fp8[1], rmsnorm_residual_out
|
||||
|
||||
def replacement(
|
||||
@@ -347,31 +358,31 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter,
|
||||
dtype=result.dtype)
|
||||
static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
|
||||
scale)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
||||
static_fp8, rmsnorm_residual_out = (
|
||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
||||
)
|
||||
)
|
||||
all_gather = self._all_gather(static_fp8[1])
|
||||
return all_gather, rmsnorm_residual_out
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
|
||||
op: torch._ops.OpOverload):
|
||||
def __init__(
|
||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
||||
):
|
||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
||||
|
||||
def get_inputs(self):
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4],
|
||||
device=self.device,
|
||||
dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
@@ -384,7 +395,6 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
@@ -394,7 +404,8 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
result, all_reduce, residual, rms_norm_weights, scale)
|
||||
result, all_reduce, residual, rms_norm_weights, scale
|
||||
)
|
||||
return static_fp8[1]
|
||||
|
||||
def replacement(
|
||||
@@ -405,16 +416,16 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter,
|
||||
dtype=result.dtype)
|
||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights,
|
||||
scale)
|
||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
||||
)
|
||||
normalized = self._all_gather(static_fp8[1])
|
||||
return normalized
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
@@ -442,30 +453,34 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="sequence_parallelism_pass")
|
||||
pass_name="sequence_parallelism_pass"
|
||||
)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# RMSNorm + Static FP8 quantization patterns
|
||||
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
|
||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
||||
).register(self.patterns)
|
||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
||||
).register(self.patterns)
|
||||
LastAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device,
|
||||
fp8_quant_op).register(self.patterns)
|
||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
||||
).register(self.patterns)
|
||||
|
||||
# Normal RMSNorm patterns
|
||||
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
FirstAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
MiddleAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
|
||||
self.device).register(self.patterns)
|
||||
LastAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
|
||||
|
||||
@@ -37,6 +37,8 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
|
||||
return self.uuid()
|
||||
|
||||
def __setstate__(self, state):
|
||||
raise ValueError("Cannot unpickle CustomGraphPass because pickling"
|
||||
" is used for cache key uuid. Use torch>=2.6 with"
|
||||
" native uuid support for custom passes.")
|
||||
raise ValueError(
|
||||
"Cannot unpickle CustomGraphPass because pickling"
|
||||
" is used for cache key uuid. Use torch>=2.6 with"
|
||||
" native uuid support for custom passes."
|
||||
)
|
||||
|
||||
@@ -8,8 +8,7 @@ from typing import ClassVar, Optional
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
||||
PatternPrettyPrinter)
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
@@ -24,20 +23,18 @@ class VllmInductorPass(InductorPass):
|
||||
An inductor pass with access to vLLM PassConfig.
|
||||
It provides timing, logging, and dumping utilities.
|
||||
"""
|
||||
|
||||
dump_prefix: ClassVar[Optional[int]] = None
|
||||
"""Keep track of pass index for debug dump ordering."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
self.model_dtype = config.model_config.dtype if config.model_config \
|
||||
else None
|
||||
self.device = config.device_config.device if config.device_config \
|
||||
else None
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def time_and_log(call_fn):
|
||||
|
||||
@functools.wraps(call_fn)
|
||||
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
@@ -51,8 +48,9 @@ class VllmInductorPass(InductorPass):
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str):
|
||||
i = VllmInductorPass.dump_prefix
|
||||
i_str = "" if i is None else f".{i}"
|
||||
lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
|
||||
graph.owning_module)
|
||||
lazy_format_graph_code(
|
||||
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
|
||||
)
|
||||
|
||||
def begin(self):
|
||||
self._start_time = time.perf_counter_ns()
|
||||
@@ -71,11 +69,13 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
|
||||
TODO(luka) move more utilities to this pass.
|
||||
"""
|
||||
|
||||
matched_count: int = 0
|
||||
"""The number of matched patterns in the pass."""
|
||||
|
||||
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
|
||||
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
|
||||
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
|
||||
)
|
||||
|
||||
def _replace_op_overloads(self, string: str) -> str:
|
||||
"""Replace <OpOverload(..., ...)> with nicer formulations"""
|
||||
@@ -102,19 +102,22 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
debug_dump_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from vllm.utils import unique_filepath
|
||||
|
||||
file_path = unique_filepath(
|
||||
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")
|
||||
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
|
||||
)
|
||||
|
||||
with file_path.open("w") as f:
|
||||
print(
|
||||
f'# This file was produced by VllmPatternMatcherPass.'
|
||||
f'dump_patterns for {self.pass_name}.\n'
|
||||
f'# It does its best to produce valid-Python-looking code but'
|
||||
f' please add to dump_patterns if there are any errors.\n\n'
|
||||
f'from torch._higher_order_ops.auto_functionalize import '
|
||||
f'auto_functionalized as auto_functionalized\n'
|
||||
f'from torch._inductor.pattern_matcher import *',
|
||||
file=f)
|
||||
f"# This file was produced by VllmPatternMatcherPass."
|
||||
f"dump_patterns for {self.pass_name}.\n"
|
||||
f"# It does its best to produce valid-Python-looking code but"
|
||||
f" please add to dump_patterns if there are any errors.\n\n"
|
||||
f"from torch._higher_order_ops.auto_functionalize import "
|
||||
f"auto_functionalized as auto_functionalized\n"
|
||||
f"from torch._inductor.pattern_matcher import *",
|
||||
file=f,
|
||||
)
|
||||
|
||||
for node, patterns in pm_pass.patterns.items():
|
||||
# fix the operator.getitem repr
|
||||
@@ -133,18 +136,21 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
|
||||
# Assemble pattern
|
||||
out_node = pp.pretty_print(pattern.pattern)
|
||||
pattern_repr = "\n".join([f"def pattern_{i}():"] + [
|
||||
f"{pp.memoized_objs_names[key]} = "
|
||||
f"{pp.memoized_objs_pp[key]}"
|
||||
for key in pp.memoized_objs_names
|
||||
] + [f"return {out_node}"]).replace("\n", "\n ")
|
||||
pattern_repr = "\n".join(
|
||||
[f"def pattern_{i}():"]
|
||||
+ [
|
||||
f"{pp.memoized_objs_names[key]} = "
|
||||
f"{pp.memoized_objs_pp[key]}"
|
||||
for key in pp.memoized_objs_names
|
||||
]
|
||||
+ [f"return {out_node}"]
|
||||
).replace("\n", "\n ")
|
||||
|
||||
pattern_repr = self._replace_op_overloads(pattern_repr)
|
||||
print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, name: str, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.name = name
|
||||
|
||||
@@ -10,8 +10,7 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import (CompilationLevel, CUDAGraphMode,
|
||||
get_current_vllm_config)
|
||||
from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -30,10 +29,9 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
`torch.compile` over the forward method.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
compiled_callable: Optional[Callable] = None,
|
||||
compilation_level: int = 0):
|
||||
|
||||
def __init__(
|
||||
self, compiled_callable: Optional[Callable] = None, compilation_level: int = 0
|
||||
):
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
if compiled_callable is None:
|
||||
@@ -43,13 +41,13 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
options = None
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = get_current_vllm_config(
|
||||
).compilation_config.inductor_compile_config
|
||||
options = (
|
||||
get_current_vllm_config().compilation_config.inductor_compile_config
|
||||
)
|
||||
|
||||
compiled_callable = torch.compile(self.forward,
|
||||
fullgraph=True,
|
||||
backend=backend,
|
||||
options=options)
|
||||
compiled_callable = torch.compile(
|
||||
self.forward, fullgraph=True, backend=backend, options=options
|
||||
)
|
||||
|
||||
self.compiled_callable = compiled_callable
|
||||
self.original_code_object = self.__class__.forward.__code__
|
||||
@@ -59,8 +57,9 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
# read the env var to determine whether to use the custom dispatcher
|
||||
# subclasses can use this to switch between the custom dispatcher
|
||||
# and the default Dynamo guard mechanism.
|
||||
self.use_custom_dispatcher: bool = \
|
||||
self.use_custom_dispatcher: bool = (
|
||||
compilation_level >= CompilationLevel.DYNAMO_ONCE
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Implement the dispatch logic here, beyond the torch.compile level.
|
||||
@@ -70,8 +69,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
return self.compiled_callable(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
...
|
||||
def forward(self, *args, **kwargs): ...
|
||||
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
|
||||
"""Hook to save the compiled bytecode for direct execution."""
|
||||
@@ -103,21 +101,27 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
# but there's no 100% guarantee, since decompliation is
|
||||
# not a reversible process.
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
|
||||
with open(decompiled_file, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug("Dynamo transformed code saved to %s",
|
||||
decompiled_file)
|
||||
logger.debug("Dynamo transformed code saved to %s", decompiled_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode != \
|
||||
CUDAGraphMode.NONE and "update" in new_code.co_names:
|
||||
if (
|
||||
self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and "update" in new_code.co_names
|
||||
):
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
|
||||
msg = (
|
||||
"Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n"
|
||||
+ src
|
||||
) # noqa
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
@@ -129,7 +133,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
the code object in the function and call it.
|
||||
|
||||
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
|
||||
""" # noqa
|
||||
""" # noqa
|
||||
self.__class__.forward.__code__ = self.compiled_codes[index]
|
||||
yield
|
||||
self.__class__.forward.__code__ = self.original_code_object
|
||||
|
||||
Reference in New Issue
Block a user