[Feature] Add Layer-wise NVTX Support (#29990)
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Signed-off-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
This commit is contained in:
@@ -14,6 +14,7 @@ import torch._C._dynamo.guards
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -92,12 +93,29 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
    """Run ``callable_fn(*args, **kwargs)``, optionally under an NVTX marker.

    When ``self.layerwise_nvtx_tracing_enabled`` is falsy, the callable is
    invoked directly. Otherwise the call is made inside
    ``layerwise_nvtx_marker_context``; the callable's return value is
    stored on the context object's ``result`` attribute and that attribute
    is what gets returned.

    Args:
        callable_fn: The callable to invoke.
        *args: Positional arguments forwarded unchanged to ``callable_fn``.
        **kwargs: Keyword arguments forwarded unchanged to ``callable_fn``.

    Returns:
        The value produced by ``callable_fn`` (read back from the marker
        context's ``result`` attribute when tracing is enabled).
    """
    # Fast path: tracing disabled, so call straight through.
    if not self.layerwise_nvtx_tracing_enabled:
        return callable_fn(*args, **kwargs)

    # The marker context receives copies of the inputs (a list and a dict);
    # the callable itself still receives the original args/kwargs.
    positional_inputs = list(args)
    keyword_inputs = dict(kwargs)
    marker_name = "Torch Compiled Module (input):{}".format(
        self.__class__.__name__
    )
    with layerwise_nvtx_marker_context(
        marker_name,
        self,
        in_tensor=positional_inputs,
        kwargs=keyword_inputs,
    ) as marker:
        marker.result = callable_fn(*args, **kwargs)
    return marker.result
|
||||
|
||||
def __init__(self):
|
||||
self.compiled = False
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
mode = vllm_config.compilation_config.mode
|
||||
self.layerwise_nvtx_tracing_enabled = (
|
||||
vllm_config.observability_config.enable_layerwise_nvtx_tracing
|
||||
)
|
||||
if mode is None:
|
||||
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
|
||||
|
||||
@@ -168,13 +186,19 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
# Make sure a compilation is triggered by clearing dynamo
|
||||
# cache.
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
with self._dispatch_to_compiled_code():
|
||||
return self.forward(*args, **kwargs)
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self.forward, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
with _compilation_context():
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
|
||||
# Abstract hook that subclasses must implement. The wrapper's call paths
# visible in this file invoke it as ``self.forward(*args, **kwargs)``,
# passing through the caller's positional and keyword arguments.
@abstractmethod
def forward(self, *args, **kwargs): ...
|
||||
|
||||
Reference in New Issue
Block a user