Support RL online quantization with torchao (#23014)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -261,8 +261,35 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model))
|
||||
|
||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||
# set to True, it means that this is either offline quantization case
|
||||
# or the first run of online quantization
|
||||
# see online_quantization.py for detailed notes
|
||||
offline_quantization_or_first_run_of_online_quantization = not getattr(
|
||||
model, "weight_metadata_and_attr_saved", False)
|
||||
|
||||
if model_config.quantization is None:
|
||||
# model is not quantized
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model))
|
||||
elif offline_quantization_or_first_run_of_online_quantization:
|
||||
# case 1: offline quantized checkpoint
|
||||
# case 2: Step I1 first run of weight loading with
|
||||
# online quantization
|
||||
# see online_quantization.py for detailed notes
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model))
|
||||
else:
|
||||
# to avoid circular dependency
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
load_weights_and_online_quantize)
|
||||
|
||||
# subsequent runs of weight loading with online
|
||||
# quantization
|
||||
loaded_weights = load_weights_and_online_quantize(
|
||||
self, model, model_config)
|
||||
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
|
||||
Reference in New Issue
Block a user