[Feature] Add Layer-wise NVTX Support (#29990)

Signed-off-by: Max Hu <hyoung2991@gmail.com>
Signed-off-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
This commit is contained in:
Max Hu
2025-12-05 06:20:07 -05:00
committed by GitHub
parent 3628bcaaf2
commit c2894d3883
5 changed files with 375 additions and 3 deletions

View File

@@ -14,6 +14,7 @@ import torch._C._dynamo.guards
import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__)
@@ -92,12 +93,29 @@ class TorchCompileWithNoGuardsWrapper:
return self.forward(*args, **kwargs)
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
    """Invoke ``callable_fn``, optionally inside a layer-wise NVTX marker.

    When ``self.layerwise_nvtx_tracing_enabled`` is falsy the callable is
    invoked directly. Otherwise the call runs inside
    ``layerwise_nvtx_marker_context``; its return value is stored on the
    context object as ``result`` and then handed back to the caller.

    Args:
        callable_fn: The callable to run.
        *args: Positional arguments forwarded unchanged to ``callable_fn``.
        **kwargs: Keyword arguments forwarded unchanged to ``callable_fn``.

    Returns:
        The value stored on the marker context's ``result`` attribute when
        tracing is enabled, otherwise ``callable_fn(*args, **kwargs)``.
    """
    # Fast path: tracing disabled, no marker bookkeeping at all.
    if not self.layerwise_nvtx_tracing_enabled:
        return callable_fn(*args, **kwargs)

    marker_name = "Torch Compiled Module (input):{}".format(
        self.__class__.__name__
    )
    # The marker receives copies of the inputs (a list of the positional
    # arguments and a fresh dict of the keyword arguments); the callable
    # itself still gets the original args/kwargs.
    with layerwise_nvtx_marker_context(
        marker_name,
        self,
        in_tensor=list(args),
        kwargs=dict(kwargs),
    ) as marker:
        # NOTE(review): presumably the context manager reads ``result`` on
        # exit to annotate the outputs -- confirm against
        # vllm.utils.nvtx_pytorch_hooks.
        marker.result = callable_fn(*args, **kwargs)
    return marker.result
def __init__(self):
self.compiled = False
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
mode = vllm_config.compilation_config.mode
self.layerwise_nvtx_tracing_enabled = (
vllm_config.observability_config.enable_layerwise_nvtx_tracing
)
if mode is None:
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
@@ -168,13 +186,19 @@ class TorchCompileWithNoGuardsWrapper:
# Make sure a compilation is triggered by clearing dynamo
# cache.
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
return self._compiled_callable(*args, **kwargs)
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
else:
with self._dispatch_to_compiled_code():
return self.forward(*args, **kwargs)
return self._call_with_optional_nvtx_range(
self.forward, *args, **kwargs
)
else:
with _compilation_context():
return self._compiled_callable(*args, **kwargs)
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
@abstractmethod
def forward(self, *args, **kwargs):
    """Abstract forward hook that concrete subclasses must implement.

    The wrapper calls this method directly or through its compiled
    callable, passing positional and keyword arguments through unchanged.
    This base definition has no behavior of its own and returns ``None``.
    """