Expand MLA to support most types of quantization (#13181)
This commit is contained in:
@@ -153,6 +153,30 @@ def _initialize_model(
|
||||
return model_class(**kwargs)
|
||||
|
||||
|
||||
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
|
||||
target_device: torch.device) -> None:
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
|
||||
# Currently only used by MLA.
|
||||
# NOTE: This intentionally happens after other modules so we can easily
|
||||
# decompress the weights for MLA.
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# TODO(lucas): see if there is a way to unify the signatures
|
||||
# of process_weights_after_loading
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
|
||||
"""Base class for model loaders."""
|
||||
|
||||
@@ -376,7 +400,6 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
@@ -394,23 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# When attention modules need to process weights after
|
||||
# currently only used by MLA
|
||||
# TODO(lucas): see if there is a way to unify the signatures
|
||||
# of process_weights_after_loading
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
_process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
@@ -429,29 +437,15 @@ class DummyModelLoader(BaseModelLoader):
|
||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
with target_device:
|
||||
model = _initialize_model(vllm_config=vllm_config)
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(
|
||||
module, torch.device(device_config.device)):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# When attention modules need to process weights after
|
||||
# currently only used by MLA
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
_process_weights_after_loading(model, model_config, target_device)
|
||||
return model.eval()
|
||||
|
||||
|
||||
@@ -632,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
target_device = torch.device(device_config.device)
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
@@ -640,18 +635,10 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
model_config.revision)
|
||||
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
with target_device:
|
||||
model = _initialize_model(vllm_config=vllm_config)
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
quant_method.process_weights_after_loading(module)
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# When attention modules need to process weights after
|
||||
# currently only used by MLA
|
||||
module.process_weights_after_loading(
|
||||
model_config.dtype)
|
||||
_process_weights_after_loading(model, model_config,
|
||||
target_device)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
pattern = os.path.join(
|
||||
local_model_path,
|
||||
@@ -1401,16 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
self._get_weights_iterator(model_weights,
|
||||
model_config.revision))
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# When attention modules need to process weights after
|
||||
# currently only used by MLA
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
_process_weights_after_loading(model, model_config, target_device)
|
||||
return model.eval()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user