[Misc] Refactor linear layer weight loading; introduce BasevLLMParameter and weight_loader_v2 (#5874)
This commit is contained in:
@@ -13,10 +13,14 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
|
||||
|
||||
|
||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||
marlin_tile_size = getattr(param, "marlin_tile_size", None)
|
||||
@@ -288,6 +292,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size,
|
||||
@@ -295,7 +300,9 @@ class ColumnParallelLinear(LinearBase):
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
|
||||
prefix=prefix)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
@@ -337,6 +344,9 @@ class ColumnParallelLinear(LinearBase):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
@@ -527,6 +537,62 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
"""
|
||||
Handle special case for models where MLP layers are already
|
||||
fused on disk. In this case, we have no shard id. This function
|
||||
determmines the shard id by splitting these layers and then calls
|
||||
the weight loader using the shard id.
|
||||
|
||||
An example of a model with these fused layers:
|
||||
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
|
||||
"""
|
||||
|
||||
current_shard_offset = 0
|
||||
shard_offsets: List[Tuple[int, int, int]] = []
|
||||
for i, output_size in enumerate(self.output_sizes):
|
||||
shard_offsets.append((i, current_shard_offset, output_size))
|
||||
current_shard_offset += output_size
|
||||
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if isinstance(param, PackedvLLMParameter
|
||||
) and param.packed_dim == param.output_dim:
|
||||
param.adjust_shard_indexes_for_packing(
|
||||
shard_size=shard_size, shard_offset=shard_offset)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
|
||||
shard_offset,
|
||||
shard_size)
|
||||
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||
|
||||
def weight_loader_v2(self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None):
|
||||
param_data = param.data
|
||||
if loaded_shard_id is None:
|
||||
if param.output_dim is None:
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||
return
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
|
||||
param.load_merged_column_weight(loaded_weight=loaded_weight,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layers for the attention's QKV transformation.
|
||||
@@ -598,6 +664,82 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||
shard_offset_mapping = {
|
||||
"q": 0,
|
||||
"k": self.num_heads * self.head_size,
|
||||
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
|
||||
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
|
||||
}
|
||||
return shard_offset_mapping.get(loaded_shard_id)
|
||||
|
||||
def _get_shard_size_mapping(self, loaded_shard_id: str):
|
||||
shard_size_mapping = {
|
||||
"q": self.num_heads * self.head_size,
|
||||
"k": self.num_kv_heads * self.head_size,
|
||||
"v": self.num_kv_heads * self.head_size,
|
||||
}
|
||||
return shard_size_mapping.get(loaded_shard_id)
|
||||
|
||||
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
"""
|
||||
Handle special case for models where QKV layers are already
|
||||
fused on disk. In this case, we have no shard id. This function
|
||||
determmines the shard id by splitting these layers and then calls
|
||||
the weight loader using the shard id.
|
||||
|
||||
An example of a model with these fused layers:
|
||||
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
|
||||
"""
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("q", 0, self.total_num_heads * self.head_size),
|
||||
("k", self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size),
|
||||
("v",
|
||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size),
|
||||
]
|
||||
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if isinstance(param, PackedvLLMParameter
|
||||
) and param.packed_dim == param.output_dim:
|
||||
param.adjust_shard_indexes_for_packing(
|
||||
shard_size=shard_size, shard_offset=shard_offset)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
|
||||
shard_offset,
|
||||
shard_size)
|
||||
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||
|
||||
def weight_loader_v2(self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param_data = param.data
|
||||
if loaded_shard_id is None: # special case for certain models
|
||||
if param.output_dim is None:
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||
return
|
||||
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
|
||||
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
|
||||
shard_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight,
|
||||
num_heads=self.num_kv_head_replicas,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
@@ -798,6 +940,7 @@ class RowParallelLinear(LinearBase):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
@@ -805,7 +948,9 @@ class RowParallelLinear(LinearBase):
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
|
||||
prefix=prefix)
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
@@ -850,6 +995,10 @@ class RowParallelLinear(LinearBase):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def weight_loader_v2(self, param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
|
||||
Reference in New Issue
Block a user