[UX] Add vLLM model inspection view (#29450)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -348,6 +348,9 @@ class LLM:
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
|
||||
# Cache for __repr__ to avoid repeated collective_rpc calls
|
||||
self._cached_repr: str | None = None
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
    """Return the tokenizer owned by the underlying engine."""
    tokenizer = self.llm_engine.get_tokenizer()
    return tokenizer
|
||||
|
||||
@@ -1786,3 +1789,16 @@ class LLM:
|
||||
# This is necessary because some requests may be finished earlier than
|
||||
# its previous requests.
|
||||
return sorted(outputs, key=lambda x: int(x.request_id))
|
||||
|
||||
def __repr__(self) -> str:
    """Return a transformers-style hierarchical view of the model."""
    # Cache the result to avoid repeated collective_rpc calls.
    cached = self._cached_repr
    if cached is not None:
        return cached

    # In distributed settings we get one result per worker; they should
    # all be identical, so the first one is used.
    worker_views = self.llm_engine.collective_rpc("get_model_inspection")
    if worker_views:
        cached = worker_views[0]
    else:
        cached = f"LLM(model={self.model_config.model!r})"

    self._cached_repr = cached
    return cached
|
||||
|
||||
@@ -250,6 +250,7 @@ if TYPE_CHECKING:
|
||||
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
|
||||
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
||||
VLLM_USE_V2_MODEL_RUNNER: bool = False
|
||||
VLLM_LOG_MODEL_INSPECTION: bool = False
|
||||
VLLM_DEBUG_MFU_METRICS: bool = False
|
||||
|
||||
|
||||
@@ -1595,6 +1596,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
|
||||
),
|
||||
# Log model inspection after loading.
|
||||
# If enabled, logs a transformers-style hierarchical view of the model
|
||||
# with quantization methods and attention backends.
|
||||
"VLLM_LOG_MODEL_INSPECTION": lambda: bool(
|
||||
int(os.getenv("VLLM_LOG_MODEL_INSPECTION", "0"))
|
||||
),
|
||||
# Debug logging for --enable-mfu-metrics
|
||||
"VLLM_DEBUG_MFU_METRICS": lambda: bool(
|
||||
int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0"))
|
||||
|
||||
@@ -285,5 +285,5 @@ class ApplyRotaryEmb(CustomOp):
|
||||
|
||||
def extra_repr(self) -> str:
    """Return the extra-repr fields shown inside the module's repr.

    Reports the neox-style flag and whether fp32 compute is enabled,
    separated by ", " so the combined repr stays readable.
    """
    s = f"is_neox_style={self.is_neox_style}"
    # Note: exactly one enable_fp32_compute field, with a ", " separator;
    # the un-separated duplicate line was a merge artifact.
    s += f", enable_fp32_compute={self.enable_fp32_compute}"
    return s
|
||||
|
||||
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
@@ -50,8 +51,21 @@ class BaseModelLoader(ABC):
|
||||
vllm_config=vllm_config, model_config=model_config
|
||||
)
|
||||
|
||||
log_model_inspection(model)
|
||||
|
||||
logger.debug("Loading weights on %s ...", load_device)
|
||||
# Quantization does not happen in `load_weights` but after it
|
||||
self.load_weights(model, model_config)
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
def log_model_inspection(model: nn.Module) -> None:
    """Log model structure if VLLM_LOG_MODEL_INSPECTION=1."""
    if envs.VLLM_LOG_MODEL_INSPECTION:
        # Imported lazily so the formatter is only pulled in when enabled.
        from vllm.model_inspection import format_model_inspection

        logger.info("vLLM model structure:\n%s", format_model_inspection(model))
|
||||
|
||||
vllm/model_inspection.py — new file, 136 lines
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Model inspection utilities for vLLM."""
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _get_module_info(module: nn.Module) -> str:
    """Return a one-line summary string for a module.

    The summary is ``ClassName(quant=..., <extra_repr>)``, where the quant
    entry appears only for quantized modules and the extra_repr part only
    when the module reports a non-empty one.  Objects without an
    ``extra_repr`` attribute fall back to their default repr.
    """
    class_name = type(module).__name__
    parts: list[str] = []

    # Add quant_method if present
    quant_method = getattr(module, "quant_method", None)
    if quant_method is not None:
        quant_name = type(quant_method).__name__
        # For CompressedTensors, show the underlying scheme instead
        scheme = getattr(module, "scheme", None)
        if scheme is not None:
            quant_name = type(scheme).__name__
        # Skip unquantized methods
        if "Unquantized" not in quant_name:
            parts.append(f"quant={quant_name}")

    # If module has extra_repr, use it.  Only append a non-empty result:
    # nn.Module.extra_repr() returns "" by default, and appending "" would
    # leave a dangling separator, e.g. "Cls(quant=X, )".
    has_extra_repr = hasattr(module, "extra_repr")
    if has_extra_repr:
        extra = module.extra_repr().replace("\n", "")
        if extra:
            parts.append(extra)

    # Any nn.Module (all define extra_repr) renders in ClassName(...) form,
    # even when parts is empty — matching the previous output shape.
    if parts or has_extra_repr:
        return f"{class_name}({', '.join(parts)})"

    # For unknown modules, use the default PyTorch repr
    return str(module)
|
||||
|
||||
|
||||
def _get_child_signature(child: nn.Module) -> str:
    """Build a structural fingerprint of a child module.

    Two children with the same signature have identical submodule trees,
    which lets the formatter collapse them into one grouped entry.
    """
    return "\n".join(
        f"{sub_name}:{_get_module_info(sub)}"
        for sub_name, sub in child.named_modules()
    )
|
||||
|
||||
|
||||
def _format_index_ranges(indices: list[int]) -> str:
    """Format indices into range notation (e.g., [0,1,2,4,5,6] -> '0-2, 4-6').

    Consecutive runs collapse to "start-end"; isolated values stay as-is.
    Returns "" for an empty input instead of raising IndexError.
    """
    if not indices:
        return ""

    indices = sorted(indices)
    ranges: list[str] = []
    start = end = indices[0]

    for idx in indices[1:]:
        if idx == end + 1:
            # Extends the current consecutive run.
            end = idx
        else:
            # Run broken: flush the finished range and start a new one.
            ranges.append(str(start) if start == end else f"{start}-{end}")
            start = end = idx

    # Flush the final (possibly single-element) range.
    ranges.append(str(start) if start == end else f"{start}-{end}")
    return ", ".join(ranges)
|
||||
|
||||
|
||||
def _format_module_tree(
    module: nn.Module,
    name: str = "",
    indent: int = 0,
) -> list[str]:
    """Format a module tree with indentation, grouping identical layers.

    Recursively renders *module* and its children, one line per entry,
    collapsing structurally-identical numbered children (per
    ``_get_child_signature``) into a single "N x" group.

    Produces output like:
      (layers): ModuleList(
        (0-27, 29-47): 47 x LlamaDecoderLayer(
          ...
        )
        (28, 48): 2 x DifferentDecoderLayer(
          ...
        )
      )
    """
    lines: list[str] = []
    # One space of indentation per nesting level.
    prefix = " " * indent
    children = list(module.named_children())

    # Leaf node - just output the module info
    if not children:
        info = _get_module_info(module)
        lines.append(f"{prefix}({name}): {info}" if name else f"{prefix}{info}")
        return lines

    # Non-leaf node - output opening line and recurse into children
    info = _get_module_info(module)
    lines.append(f"{prefix}({name}): {info}(" if name else f"{prefix}{info}(")

    # Separate numbered children (e.g., "0", "1") from named ones (e.g., "norm")
    numbered: list[tuple[int, nn.Module]] = []
    non_numbered: list[tuple[str, nn.Module]] = []
    for child_name, child_module in children:
        # EAFP: a child name that parses as an int is a list-style index.
        try:
            numbered.append((int(child_name), child_module))
        except ValueError:
            non_numbered.append((child_name, child_module))

    # Group numbered children by structure signature to collapse identical layers
    # e.g., layers 0-27 and 29-47 with same structure become "(0-27, 29-47): 47 x"
    if numbered:
        sig_to_group: dict[str, list[tuple[int, nn.Module]]] = {}
        for idx, child_module in numbered:
            sig = _get_child_signature(child_module)
            sig_to_group.setdefault(sig, []).append((idx, child_module))

        # Output groups sorted by first index
        for group in sorted(sig_to_group.values(), key=lambda g: g[0][0]):
            indices = [idx for idx, _ in group]
            # Render one representative; all group members are identical.
            representative = group[0][1]
            child_lines = _format_module_tree(representative, "", indent + 1)
            first_line = child_lines[0].lstrip()
            child_prefix = " " * (indent + 1)

            # Rewrite the first rendered line in place to prepend the
            # "(indices): [count x]" label at the child indent level.
            if len(indices) > 1:
                range_str = _format_index_ranges(indices)
                child_lines[0] = (
                    f"{child_prefix}({range_str}): {len(indices)} x {first_line}"
                )
            else:
                child_lines[0] = f"{child_prefix}({indices[0]}): {first_line}"
            lines.extend(child_lines)

    # Output non-numbered children (e.g., "embed_tokens", "norm")
    for child_name, child_module in non_numbered:
        lines.extend(_format_module_tree(child_module, child_name, indent + 1))

    lines.append(f"{prefix})")
    return lines
|
||||
|
||||
|
||||
def format_model_inspection(model: nn.Module) -> str:
    """Render a model as a transformers-style hierarchical string."""
    tree_lines = _format_module_tree(model)
    return "\n".join(tree_lines)
|
||||
@@ -118,6 +118,12 @@ class WorkerBase:
|
||||
"""Apply a function on the model inside this worker."""
|
||||
return fn(self.get_model())
|
||||
|
||||
def get_model_inspection(self) -> str:
    """Return a transformers-style hierarchical view of the model."""
    # Imported lazily to keep the worker module import-light.
    from vllm.model_inspection import format_model_inspection

    model = self.get_model()
    return format_model_inspection(model)
|
||||
|
||||
def load_model(self) -> None:
    """Load model onto target device.

    Abstract-by-convention hook: concrete worker subclasses must override
    this to materialize the model on their device.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user