[QeRL] Layerwise Reloading (#32133)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -18,6 +18,10 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.model_loader.reload import (
|
||||
record_metadata_for_reloading,
|
||||
set_torchao_reload_attrs,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
@@ -45,7 +49,9 @@ def initialize_model(
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
model = model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
record_metadata_for_reloading(model)
|
||||
return model
|
||||
|
||||
msg = (
|
||||
"vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
@@ -75,27 +81,15 @@ def initialize_model(
|
||||
if "scheduler_config" in all_params:
|
||||
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
||||
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
|
||||
return model_class(**kwargs)
|
||||
model = model_class(**kwargs)
|
||||
record_metadata_for_reloading(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def process_weights_after_loading(
|
||||
model: nn.Module, model_config: ModelConfig, target_device: torch.device
|
||||
) -> None:
|
||||
if getattr(model, "process_weights_after_loading_already_called", False):
|
||||
# In case `process_weights_after_loading` is called multiple times
|
||||
# we'll skip it at later times
|
||||
logger.debug_once(
|
||||
"process_weights_after_loading already called for model %s", model
|
||||
)
|
||||
return
|
||||
|
||||
# to avoid circular dependency
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
maybe_save_metadata_and_attributes_for_weight_reloading,
|
||||
)
|
||||
|
||||
maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
@@ -117,6 +111,11 @@ def process_weights_after_loading(
|
||||
# of process_weights_after_loading
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
# Needed for torchao model reloading via model.reload_weights
|
||||
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
|
||||
if model_config.quantization == "torchao":
|
||||
set_torchao_reload_attrs(model, model_config)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
|
||||
|
||||
Reference in New Issue
Block a user