Support RL online quantization with torchao (#23014)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-10-01 16:39:29 -07:00
committed by GitHub
parent 4134312b35
commit c31246800c
6 changed files with 465 additions and 16 deletions

View File

@@ -95,6 +95,13 @@ def initialize_model(
def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
maybe_save_metadata_and_attributes_for_weight_reloading)
maybe_save_metadata_and_attributes_for_weight_reloading(
model, model_config)
for _, module in model.named_modules():
if isinstance(module, QKVCrossParallelLinear):
# NOTE(Isotr0py): special case for cross QKV layer because
@@ -243,7 +250,7 @@ def get_architecture_class_name(model_config: ModelConfig) -> str:
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: dict[str, list[str]]