[4/N][torch.compile] clean up set_torch_compile_backend (#10401)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-17 23:57:20 -08:00
Committed by: GitHub
Parent: 47826cacf0
Commit: 51bb12d17b

7 changed files with 49 additions and 42 deletions


@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once)
+                        identity, print_warning_once, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
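
For readers unfamiliar with the newly imported helper: resolve_obj_by_qualname takes a dotted string such as "my_pkg.my_module.my_backend" and returns the object it names. A minimal sketch of that behavior (illustrative only; the actual implementation lives in vllm.utils and may differ):

import importlib
from typing import Any

def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
    # Split "pkg.module.attr" into module path and attribute name,
    # import the module, and return the named attribute.
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)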
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
         - 1: dynamo as is.
         - 2: dynamo once.
         - 3: piecewise compilation.
+    - backend: the backend for compilation. It needs to be a string.
+        - "" (empty string): use the default backend.
+        - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
+        - "full.module.name": a qualified name which can be used to import the backend function.
+            We use string to avoid serialization issues when using compilation in a distributed setting.
+            When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
+            When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
     - custom_ops: fine-grained control over which custom ops to enable/disable.
         Use 'all' to enable all, 'none' to disable all.
         Also specify a list of custom op names to enable (prefixed with a '+'),
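
To illustrate the backend options documented above, a hedged sketch of how the new field might be set (CompilationConfig is assumed importable from vllm.config; "my_pkg.backends.my_backend" is a hypothetical module path):

from vllm.config import CompilationConfig

# Level 2 (dynamo once) with a backend already registered in PyTorch.
cfg = CompilationConfig(level=2, backend="eager")

# Level 2 with a user-defined backend referenced by qualified name, so the
# config stays a plain string and serializes cleanly in distributed settings.
cfg = CompilationConfig(level=2, backend="my_pkg.backends.my_backend")

# Level 3 (piecewise compilation) currently ignores the user backend and
# builds vLLM's own VllmBackend (see init_backend below).
cfg = CompilationConfig(level=3)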
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
         certain small batchsizes, where inductor is good at optimizing.
     """  # noqa
     level: int = 0
+    backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     use_inductor: bool = True
@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
             func = __import__(module).__dict__[func_name]
             self.inductor_compile_config[k] = func
 
+    def init_backend(self) -> Union[str, Callable]:
+        if self.level == CompilationLevel.NO_COMPILATION:
+            raise ValueError("No compilation level is set.")
+
+        from torch._dynamo.backends.registry import list_backends
+        torch_backends = list_backends(exclude_tags=tuple())
+        if self.level in [
+                CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
+        ]:
+            if self.backend == "":
+                return "eager"
+            if self.backend in torch_backends:
+                return self.backend
+            return resolve_obj_by_qualname(self.backend)
+
+        # TODO: pass user-specified backend to piecewise compilation
+        # merge with the config use_inductor
+        assert self.level == CompilationLevel.PIECEWISE
+        from vllm.compilation.backends import VllmBackend
+        return VllmBackend(self)
+
     def init_during_runtime(self):
         """To complete the initialization of config,
         we need to know the compile context, which is only available
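
A rough sketch of how the value returned by init_backend could be fed to torch.compile under the DYNAMO_AS_IS / DYNAMO_ONCE levels (the actual wiring inside vLLM is not part of this diff; my_backend and the cfg object are illustrative):

import torch
import torch.nn as nn

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # A do-nothing custom backend: run the captured FX graph as-is.
    return gm.forward

# init_backend returns "eager" for an empty string, the string itself for a
# backend registered in PyTorch, or a callable resolved by qualified name.
backend = cfg.init_backend()  # str or Callable
compiled = torch.compile(nn.Linear(8, 8), backend=backend)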