[Bugfix] Handle process_weights_after_loading for QKVCrossParallelLinear (#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-04-09 01:02:23 +08:00
committed by GitHub
parent 4ebc0b9640
commit 40b4284fe3
3 changed files with 33 additions and 6 deletions

View File

@@ -33,11 +33,15 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import (
@@ -160,6 +164,11 @@ def _initialize_model(
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None:
for _, module in model.named_modules():
if isinstance(module, QKVCrossParallelLinear):
# NOTE(Isotr0py): special case for cross QKV layer because
# q and kv proj aren't registered as submodules intentionally
module.process_weights_after_loading()
continue
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading