Support RL online quantization with torchao (#23014)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -95,6 +95,13 @@ def initialize_model(
|
||||
|
||||
def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
|
||||
target_device: torch.device) -> None:
|
||||
|
||||
# Imported locally (inside the function) to avoid a circular dependency
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
maybe_save_metadata_and_attributes_for_weight_reloading)
|
||||
maybe_save_metadata_and_attributes_for_weight_reloading(
|
||||
model, model_config)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, QKVCrossParallelLinear):
|
||||
# NOTE(Isotr0py): special case for cross QKV layer because
|
||||
@@ -243,7 +250,7 @@ def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
packed_mapping: dict[str, list[str]]
|
||||
|
||||
Reference in New Issue
Block a user