Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -33,31 +33,31 @@ class CompilationLevel:
|
||||
|
||||
|
||||
class CUDAGraphMode(enum.Enum):
|
||||
""" Constants for the cudagraph mode in CompilationConfig.
|
||||
"""Constants for the cudagraph mode in CompilationConfig.
|
||||
Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
|
||||
treated as concrete runtime modes for cudagraph runtime dispatching.
|
||||
"""
|
||||
|
||||
NONE = 0
|
||||
PIECEWISE = 1
|
||||
FULL = 2
|
||||
FULL_DECODE_ONLY = (FULL, NONE)
|
||||
FULL_AND_PIECEWISE = (FULL, PIECEWISE)
|
||||
|
||||
def decode_mode(self) -> 'CUDAGraphMode':
|
||||
return CUDAGraphMode(self.value[0]) if \
|
||||
self.separate_routine() else self
|
||||
def decode_mode(self) -> "CUDAGraphMode":
|
||||
return CUDAGraphMode(self.value[0]) if self.separate_routine() else self
|
||||
|
||||
def mixed_mode(self) -> 'CUDAGraphMode':
|
||||
return CUDAGraphMode(self.value[1]) if \
|
||||
self.separate_routine() else self
|
||||
def mixed_mode(self) -> "CUDAGraphMode":
|
||||
return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
|
||||
|
||||
def requires_piecewise_compilation(self) -> bool:
|
||||
return (self.decode_mode() == CUDAGraphMode.PIECEWISE
|
||||
or self.mixed_mode() == CUDAGraphMode.PIECEWISE)
|
||||
return (
|
||||
self.decode_mode() == CUDAGraphMode.PIECEWISE
|
||||
or self.mixed_mode() == CUDAGraphMode.PIECEWISE
|
||||
)
|
||||
|
||||
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
|
||||
return CUDAGraphMode(max(
|
||||
self.value)) if self.separate_routine() else self
|
||||
def max_cudagraph_mode(self) -> "CUDAGraphMode":
|
||||
return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
|
||||
|
||||
def has_full_cudagraphs(self) -> bool:
|
||||
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
|
||||
@@ -69,9 +69,7 @@ class CUDAGraphMode(enum.Enum):
|
||||
return isinstance(self.value, tuple)
|
||||
|
||||
def valid_runtime_modes(self) -> bool:
|
||||
return self in [
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||
]
|
||||
return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
@@ -116,11 +114,13 @@ class PassConfig:
|
||||
if self.enable_fusion:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm/SiluMul + quant (fp8) fusion might not work")
|
||||
"RMSNorm/SiluMul + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_attn_fusion:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Attention + quant (fp8) fusion might not work")
|
||||
"Attention + quant (fp8) fusion might not work"
|
||||
)
|
||||
|
||||
|
||||
@config
|
||||
@@ -163,6 +163,7 @@ class CompilationConfig:
|
||||
sufficient for most cases. It might be beneficial to compile for
|
||||
certain small batchsizes, where inductor is good at optimizing.
|
||||
"""
|
||||
|
||||
# Top-level Compilation control
|
||||
level: Optional[int] = None
|
||||
"""The level of compilation:
|
||||
@@ -340,26 +341,24 @@ class CompilationConfig:
|
||||
"""local cache dir for each rank"""
|
||||
bs_to_padded_graph_size: list[int] = field(
|
||||
default=None, # type: ignore
|
||||
init=False)
|
||||
init=False,
|
||||
)
|
||||
"""optimization:
|
||||
Intuitively, bs_to_padded_graph_size should be dict[int, int].
|
||||
Since we know all keys are in a range [0, max_capture_size],
|
||||
we can optimize it to list[int] for better lookup performance."""
|
||||
|
||||
# keep track of enabled and disabled custom ops
|
||||
enabled_custom_ops: Counter[str] = field(default_factory=Counter,
|
||||
init=False)
|
||||
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
|
||||
"""custom ops that are enabled"""
|
||||
disabled_custom_ops: Counter[str] = field(default_factory=Counter,
|
||||
init=False)
|
||||
disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
|
||||
"""custom ops that are disabled"""
|
||||
traced_files: set[str] = field(default_factory=set, init=False)
|
||||
"""files that are traced for compilation"""
|
||||
compilation_time: float = field(default=0.0, init=False)
|
||||
"""time taken for compilation"""
|
||||
|
||||
static_forward_context: dict[str, Any] = field(default_factory=dict,
|
||||
init=False)
|
||||
static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
|
||||
"""Per-model forward context
|
||||
Map from layer name to layer objects that need to be accessed outside
|
||||
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
||||
@@ -421,9 +420,9 @@ class CompilationConfig:
|
||||
if pass_config_exclude:
|
||||
exclude["pass_config"] = pass_config_exclude
|
||||
|
||||
config = TypeAdapter(CompilationConfig).dump_python(self,
|
||||
exclude=exclude,
|
||||
exclude_unset=True)
|
||||
config = TypeAdapter(CompilationConfig).dump_python(
|
||||
self, exclude=exclude, exclude_unset=True
|
||||
)
|
||||
|
||||
return str(config)
|
||||
|
||||
@@ -453,16 +452,16 @@ class CompilationConfig:
|
||||
# https://github.com/vllm-project/vllm/issues/14703
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
KEY = 'enable_auto_functionalized_v2'
|
||||
KEY = "enable_auto_functionalized_v2"
|
||||
if KEY not in self.inductor_compile_config:
|
||||
self.inductor_compile_config[KEY] = False
|
||||
|
||||
for k, v in self.inductor_passes.items():
|
||||
if not isinstance(v, str):
|
||||
assert callable(v), (
|
||||
f"pass {k} should be callable or a qualified name")
|
||||
self.inductor_compile_config[k] = v if isinstance(
|
||||
v, InductorPass) else CallableInductorPass(v)
|
||||
assert callable(v), f"pass {k} should be callable or a qualified name"
|
||||
self.inductor_compile_config[k] = (
|
||||
v if isinstance(v, InductorPass) else CallableInductorPass(v)
|
||||
)
|
||||
continue
|
||||
|
||||
# resolve function from qualified name
|
||||
@@ -470,54 +469,68 @@ class CompilationConfig:
|
||||
module = ".".join(names[:-1])
|
||||
func_name = names[-1]
|
||||
func = __import__(module).__dict__[func_name]
|
||||
self.inductor_compile_config[k] = func if isinstance(
|
||||
func, InductorPass) else CallableInductorPass(func)
|
||||
self.inductor_compile_config[k] = (
|
||||
func if isinstance(func, InductorPass) else CallableInductorPass(func)
|
||||
)
|
||||
|
||||
if isinstance(self.pass_config, dict):
|
||||
self.pass_config = PassConfig(**self.pass_config)
|
||||
|
||||
# migrate the deprecated flags
|
||||
if not self.use_cudagraph:
|
||||
logger.warning("use_cudagraph is deprecated, use "
|
||||
"cudagraph_mode=NONE instead.")
|
||||
if self.cudagraph_mode is not None and \
|
||||
self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
logger.warning(
|
||||
"use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
|
||||
)
|
||||
if (
|
||||
self.cudagraph_mode is not None
|
||||
and self.cudagraph_mode != CUDAGraphMode.NONE
|
||||
):
|
||||
raise ValueError(
|
||||
"use_cudagraph and cudagraph_mode are mutually"
|
||||
" exclusive, prefer cudagraph_mode since "
|
||||
"use_cudagraph is deprecated.")
|
||||
"use_cudagraph is deprecated."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
if self.full_cuda_graph:
|
||||
logger.warning("full_cuda_graph is deprecated, use "
|
||||
"cudagraph_mode=FULL instead.")
|
||||
if self.cudagraph_mode is not None and \
|
||||
not self.cudagraph_mode.has_full_cudagraphs():
|
||||
raise ValueError("full_cuda_graph and cudagraph_mode are "
|
||||
"mutually exclusive, prefer cudagraph_mode "
|
||||
"since full_cuda_graph is deprecated.")
|
||||
logger.warning(
|
||||
"full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
|
||||
)
|
||||
if (
|
||||
self.cudagraph_mode is not None
|
||||
and not self.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
raise ValueError(
|
||||
"full_cuda_graph and cudagraph_mode are "
|
||||
"mutually exclusive, prefer cudagraph_mode "
|
||||
"since full_cuda_graph is deprecated."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
if (self.use_inductor_graph_partition
|
||||
and not is_torch_equal_or_newer("2.9.0.dev")):
|
||||
raise ValueError("use_inductor_graph_partition is only "
|
||||
"supported with torch>=2.9.0.dev. Set "
|
||||
"use_inductor_graph_partition=False instead.")
|
||||
if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"
|
||||
):
|
||||
raise ValueError(
|
||||
"use_inductor_graph_partition is only "
|
||||
"supported with torch>=2.9.0.dev. Set "
|
||||
"use_inductor_graph_partition=False instead."
|
||||
)
|
||||
|
||||
for op in self.custom_ops:
|
||||
if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
|
||||
raise ValueError(f"Invalid syntax '{op}' for custom op, "
|
||||
"must be 'all', 'none', '+op' or '-op' "
|
||||
"(where 'op' is the registered op name)")
|
||||
if op[0] not in {"+", "-"} and op not in {"all", "none"}:
|
||||
raise ValueError(
|
||||
f"Invalid syntax '{op}' for custom op, "
|
||||
"must be 'all', 'none', '+op' or '-op' "
|
||||
"(where 'op' is the registered op name)"
|
||||
)
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||
if self.level == CompilationLevel.NO_COMPILATION:
|
||||
raise ValueError("No compilation level is set.")
|
||||
|
||||
from torch._dynamo.backends.registry import list_backends
|
||||
|
||||
torch_backends = list_backends(exclude_tags=tuple())
|
||||
if self.level in [
|
||||
CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
|
||||
]:
|
||||
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
|
||||
if self.backend == "":
|
||||
return "eager"
|
||||
if self.backend in torch_backends:
|
||||
@@ -529,10 +542,10 @@ class CompilationConfig:
|
||||
assert self.level == CompilationLevel.PIECEWISE
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
|
||||
return VllmBackend(vllm_config)
|
||||
|
||||
def init_with_cudagraph_sizes(self,
|
||||
cudagraph_capture_sizes: list[int]) -> None:
|
||||
def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
|
||||
"""To complete the initialization of config,
|
||||
we need to know the cudagraph sizes."""
|
||||
|
||||
@@ -542,9 +555,14 @@ class CompilationConfig:
|
||||
# de-duplicate the sizes provided by the config
|
||||
dedup_sizes = list(set(self.cudagraph_capture_sizes))
|
||||
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
|
||||
logger.info(("cudagraph sizes specified by model runner"
|
||||
" %s is overridden by config %s"),
|
||||
cudagraph_capture_sizes, dedup_sizes)
|
||||
logger.info(
|
||||
(
|
||||
"cudagraph sizes specified by model runner"
|
||||
" %s is overridden by config %s"
|
||||
),
|
||||
cudagraph_capture_sizes,
|
||||
dedup_sizes,
|
||||
)
|
||||
self.cudagraph_capture_sizes = dedup_sizes
|
||||
|
||||
computed_compile_sizes = []
|
||||
@@ -553,9 +571,10 @@ class CompilationConfig:
|
||||
self.compile_sizes = list(set(self.compile_sizes))
|
||||
for x in self.compile_sizes:
|
||||
if isinstance(x, str):
|
||||
assert x == "cudagraph_capture_sizes", \
|
||||
"Unrecognized size type in compile_sizes, " \
|
||||
assert x == "cudagraph_capture_sizes", (
|
||||
"Unrecognized size type in compile_sizes, "
|
||||
f"expect 'cudagraph_capture_sizes', got {x}"
|
||||
)
|
||||
computed_compile_sizes.extend(self.cudagraph_capture_sizes)
|
||||
else:
|
||||
assert isinstance(x, int)
|
||||
@@ -564,29 +583,29 @@ class CompilationConfig:
|
||||
|
||||
# sort to make sure cudagraph capture sizes are in descending order
|
||||
self.cudagraph_capture_sizes.sort(reverse=True)
|
||||
self.max_capture_size = self.cudagraph_capture_sizes[
|
||||
0] if self.cudagraph_capture_sizes else 0
|
||||
self.max_capture_size = (
|
||||
self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
|
||||
)
|
||||
|
||||
# pre-compute the mapping from batch size to padded graph size
|
||||
self.bs_to_padded_graph_size = [
|
||||
0 for i in range(self.max_capture_size + 1)
|
||||
]
|
||||
for end, start in zip(self.cudagraph_capture_sizes,
|
||||
self.cudagraph_capture_sizes[1:] + [0]):
|
||||
self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)]
|
||||
for end, start in zip(
|
||||
self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
|
||||
):
|
||||
for bs in range(start, end):
|
||||
if bs == start:
|
||||
self.bs_to_padded_graph_size[bs] = start
|
||||
else:
|
||||
self.bs_to_padded_graph_size[bs] = end
|
||||
self.bs_to_padded_graph_size[
|
||||
self.max_capture_size] = self.max_capture_size
|
||||
self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size
|
||||
|
||||
def set_splitting_ops_for_v1(self):
|
||||
# NOTE: this function needs to be called only when level is
|
||||
# CompilationLevel.PIECEWISE
|
||||
assert self.level == CompilationLevel.PIECEWISE, (
|
||||
"set_splitting_ops_for_v1 should only be called when "
|
||||
"level is CompilationLevel.PIECEWISE")
|
||||
"level is CompilationLevel.PIECEWISE"
|
||||
)
|
||||
|
||||
if self.use_inductor_graph_partition:
|
||||
self.set_splitting_ops_for_inductor_graph_partition()
|
||||
@@ -608,22 +627,23 @@ class CompilationConfig:
|
||||
# list via reference.
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
elif len(self.splitting_ops) == 0:
|
||||
logger.warning_once(
|
||||
"Using piecewise compilation with empty splitting_ops")
|
||||
logger.warning_once("Using piecewise compilation with empty splitting_ops")
|
||||
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.warning_once(
|
||||
"Piecewise compilation with empty splitting_ops do not" \
|
||||
"Piecewise compilation with empty splitting_ops do not"
|
||||
"contains piecewise cudagraph. Setting cudagraph_"
|
||||
"mode to NONE. Hint: If you are using attention backends "
|
||||
"that support cudagraph, consider manually setting "
|
||||
"cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
|
||||
"full cudagraphs.")
|
||||
"full cudagraphs."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
|
||||
logger.warning_once(
|
||||
"Piecewise compilation with empty splitting_ops do not "
|
||||
"contains piecewise cudagraph. Setting cudagraph_mode "
|
||||
"to FULL.")
|
||||
"to FULL."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
self.splitting_ops = []
|
||||
|
||||
@@ -632,10 +652,10 @@ class CompilationConfig:
|
||||
use_inductor_graph_partition_msg = (
|
||||
"When use_inductor_graph_partition=True, splitting_ops "
|
||||
"are ignored and set to an empty list. Instead, "
|
||||
"\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
|
||||
"used to annotate custom ops for graph partition.")
|
||||
if self.splitting_ops is not None and \
|
||||
len(self.splitting_ops) > 0:
|
||||
'"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
|
||||
"used to annotate custom ops for graph partition."
|
||||
)
|
||||
if self.splitting_ops is not None and len(self.splitting_ops) > 0:
|
||||
logger.warning_once(use_inductor_graph_partition_msg)
|
||||
self.splitting_ops = []
|
||||
|
||||
@@ -651,32 +671,38 @@ class CompilationConfig:
|
||||
"list, and cudagraph_mode will be set to FULL. "
|
||||
"Please ensure you are using attention backends that "
|
||||
"support cudagraph or set cudagraph_mode to NONE "
|
||||
"explicitly if encountering any problems.")
|
||||
"explicitly if encountering any problems."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
assert not self.splitting_ops_contain_attention(), (
|
||||
"attention ops should not be in splitting_ops "
|
||||
"when enable_attn_fusion is True")
|
||||
"when enable_attn_fusion is True"
|
||||
)
|
||||
|
||||
def splitting_ops_contain_attention(self) -> bool:
|
||||
return self.splitting_ops is not None and all(
|
||||
op in self.splitting_ops for op in self._attention_ops)
|
||||
op in self.splitting_ops for op in self._attention_ops
|
||||
)
|
||||
|
||||
def is_attention_compiled_piecewise(self) -> bool:
|
||||
use_fx_graph_piecewise_compilation = (
|
||||
self.level == CompilationLevel.PIECEWISE
|
||||
and self.splitting_ops_contain_attention())
|
||||
and self.splitting_ops_contain_attention()
|
||||
)
|
||||
|
||||
inductor_used = (self.level == CompilationLevel.PIECEWISE
|
||||
and self.use_inductor) or (
|
||||
self.level >= CompilationLevel.DYNAMO_AS_IS
|
||||
and self.backend == "inductor")
|
||||
inductor_used = (
|
||||
self.level == CompilationLevel.PIECEWISE and self.use_inductor
|
||||
) or (
|
||||
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
|
||||
)
|
||||
use_inductor_piecewise_compilation = (
|
||||
inductor_used and self.use_inductor_graph_partition
|
||||
and not self.splitting_ops_contain_attention())
|
||||
inductor_used
|
||||
and self.use_inductor_graph_partition
|
||||
and not self.splitting_ops_contain_attention()
|
||||
)
|
||||
|
||||
return use_fx_graph_piecewise_compilation or \
|
||||
use_inductor_piecewise_compilation
|
||||
return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
|
||||
|
||||
def custom_op_log_check(self):
|
||||
"""
|
||||
@@ -693,13 +719,14 @@ class CompilationConfig:
|
||||
logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
|
||||
logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
|
||||
|
||||
all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
|
||||
all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops
|
||||
for op in self.custom_ops:
|
||||
if op in {"all", "none"}:
|
||||
continue
|
||||
|
||||
assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
|
||||
"(should be checked during init)"
|
||||
assert op[0] in {"+", "-"}, (
|
||||
"Invalid custom op syntax (should be checked during init)"
|
||||
)
|
||||
|
||||
# check if op name exists in model
|
||||
op_name = op[1:]
|
||||
@@ -708,10 +735,17 @@ class CompilationConfig:
|
||||
|
||||
# Does op exist at all or is it just not present in this model?
|
||||
# Note: Only imported op classes appear in the registry.
|
||||
missing_str = "doesn't exist (or wasn't imported/registered)" \
|
||||
if op_name not in CustomOp.op_registry \
|
||||
missing_str = (
|
||||
"doesn't exist (or wasn't imported/registered)"
|
||||
if op_name not in CustomOp.op_registry
|
||||
else "not present in model"
|
||||
)
|
||||
|
||||
enable_str = "enabling" if op[0] == '+' else "disabling"
|
||||
logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
|
||||
op_name, missing_str, enable_str, op)
|
||||
enable_str = "enabling" if op[0] == "+" else "disabling"
|
||||
logger.warning_once(
|
||||
"Op '%s' %s, %s with '%s' has no effect",
|
||||
op_name,
|
||||
missing_str,
|
||||
enable_str,
|
||||
op,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user