[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -4666,10 +4666,13 @@ class VllmConfig:
|
||||
|
||||
|
||||
_current_vllm_config: Optional[VllmConfig] = None
|
||||
_current_prefix: Optional[str] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
||||
def set_current_vllm_config(vllm_config: VllmConfig,
|
||||
check_compile=False,
|
||||
prefix: Optional[str] = None):
|
||||
"""
|
||||
Temporarily set the current vLLM config.
|
||||
Used during model initialization.
|
||||
@@ -4677,12 +4680,14 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
||||
so that all modules can access it, e.g. custom ops
|
||||
can access the vLLM config to determine how to dispatch.
|
||||
"""
|
||||
global _current_vllm_config
|
||||
global _current_vllm_config, _current_prefix
|
||||
old_vllm_config = _current_vllm_config
|
||||
old_prefix = _current_prefix
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
num_models_seen = compilation_counter.num_models_seen
|
||||
try:
|
||||
_current_vllm_config = vllm_config
|
||||
_current_prefix = prefix
|
||||
yield
|
||||
except Exception:
|
||||
raise
|
||||
@@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
|
||||
vllm_config.model_config.model)
|
||||
finally:
|
||||
_current_vllm_config = old_vllm_config
|
||||
_current_prefix = old_prefix
|
||||
|
||||
|
||||
def get_current_vllm_config() -> VllmConfig:
|
||||
@@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
|
||||
return _current_vllm_config
|
||||
|
||||
|
||||
def get_current_model_prefix() -> str:
|
||||
"""
|
||||
Get the prefix of the model that's currently being initialized.
|
||||
"""
|
||||
assert _current_prefix is not None, \
|
||||
"Current model prefix is not set. "
|
||||
return _current_prefix
|
||||
|
||||
|
||||
def contains_object_print(text):
|
||||
"""
|
||||
Check if the text looks like a printed Python object, e.g.
|
||||
|
||||
Reference in New Issue
Block a user