Support RL online quantization with torchao (#23014)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-10-01 16:39:29 -07:00
committed by GitHub
parent 4134312b35
commit c31246800c
6 changed files with 465 additions and 16 deletions

View File

@@ -261,8 +261,35 @@ class DefaultModelLoader(BaseModelLoader):
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False)
if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize)
# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(
self, model, model_config)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",