[Core] Allow disabling TP sharding for parallel Linear layer (#23024)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -223,6 +223,7 @@ class LinearBase(CustomOp):
|
||||
quant_config: Quantization configure.
|
||||
prefix: Prefix for parameter names.
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: If true, tensor parallelism will be disabled for this layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -235,6 +236,7 @@ class LinearBase(CustomOp):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -254,6 +256,17 @@ class LinearBase(CustomOp):
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
self.disable_tp = disable_tp
|
||||
self.tp_rank = (get_tensor_model_parallel_rank()
|
||||
if not disable_tp else 0)
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
|
||||
def __post_init__(self):
|
||||
for param in self.parameters():
|
||||
if isinstance(param, BasevLLMParameter):
|
||||
param.tp_rank = self.tp_rank
|
||||
param.tp_size = self.tp_size
|
||||
|
||||
|
||||
@CustomOp.register("replicated_linear")
|
||||
@@ -270,6 +283,7 @@ class ReplicatedLinear(LinearBase):
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: Take no effect for replicated linear layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -283,26 +297,21 @@ class ReplicatedLinear(LinearBase):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# If MergedReplicatedLinear, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = self.output_sizes
|
||||
else:
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
super().__init__(input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size,
|
||||
self.output_partition_sizes,
|
||||
self.input_size, [self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
@@ -358,74 +367,6 @@ class ReplicatedLinear(LinearBase):
|
||||
return s
|
||||
|
||||
|
||||
class MergedReplicatedLinear(ReplicatedLinear):
|
||||
"""Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_sizes: list of output dimensions of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
super().__init__(input_size,
|
||||
sum(output_sizes),
|
||||
bias,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Union[Parameter, BasevLLMParameter],
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None):
|
||||
assert loaded_shard_id is not None
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8LinearMethod, Fp8MoEMethod)
|
||||
assert self.quant_method is not None
|
||||
assert isinstance(self.quant_method,
|
||||
(Fp8LinearMethod, Fp8MoEMethod))
|
||||
weight_block_size = self.quant_method.quant_config.weight_block_size
|
||||
assert weight_block_size is not None
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (
|
||||
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
|
||||
block_n)
|
||||
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
|
||||
block_n)
|
||||
elif isinstance(param, PerTensorScaleParameter):
|
||||
shard_offset = loaded_shard_id
|
||||
shard_size = 1
|
||||
else:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
|
||||
param.data[shard_offset:shard_offset + shard_size] = loaded_weight
|
||||
|
||||
|
||||
@CustomOp.register("column_parallel_linear")
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
@@ -448,7 +389,9 @@ class ColumnParallelLinear(LinearBase):
|
||||
output_sizes: list of output sizes packed into one output, like for QKV
|
||||
the list would be size 3.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: If true, weights matrix won't be sharded through tp rank.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -464,9 +407,13 @@ class ColumnParallelLinear(LinearBase):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = (get_tensor_model_parallel_rank()
|
||||
if not disable_tp else 0)
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
@@ -483,7 +430,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
@@ -512,8 +460,6 @@ class ColumnParallelLinear(LinearBase):
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
@@ -554,7 +500,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
def weight_loader_v2(self, param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor):
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
if len(loaded_weight.shape) == 0:
|
||||
@@ -570,7 +517,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
if self.gather_output and self.tp_size > 1:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
@@ -584,7 +531,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
s = f"in_features={self.input_size}"
|
||||
s += f", output_features={self.output_size_per_partition}"
|
||||
s += f", bias={self.bias is not None}"
|
||||
s += f", tp_size={get_tensor_model_parallel_world_size()}"
|
||||
s += f", tp_size={self.tp_size}"
|
||||
s += f", gather_output={self.gather_output}"
|
||||
return s
|
||||
|
||||
@@ -611,6 +558,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: If true, all weights matrix won't be sharded, this layer
|
||||
will be treated as a "Replicated" MergedLinear.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -625,10 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
self.tp_rank = (get_tensor_model_parallel_rank()
|
||||
if not disable_tp else 0)
|
||||
|
||||
assert all(output_size % self.tp_size == 0
|
||||
for output_size in output_sizes)
|
||||
@@ -640,7 +592,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
@@ -832,8 +785,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8LinearMethod, Fp8MoEMethod)
|
||||
@@ -845,17 +796,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (
|
||||
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
|
||||
block_n) // tp_size
|
||||
block_n) // self.tp_size
|
||||
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
|
||||
block_n // tp_size)
|
||||
block_n // self.tp_size)
|
||||
else:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
shard_offset = sum(
|
||||
self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||
|
||||
param.load_merged_column_weight(loaded_weight=loaded_weight,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size)
|
||||
shard_size=shard_size,
|
||||
tp_rank=self.tp_rank)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
@@ -883,6 +836,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: If true, weights matrix won't be sharded through tp rank.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -898,6 +852,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
@@ -906,7 +861,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
@@ -932,7 +888,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||
shard_offset_mapping = {
|
||||
@@ -993,10 +950,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
if loaded_shard_id is None: # special case for certain models
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight,
|
||||
shard_id=0,
|
||||
tp_rank=self.tp_rank)
|
||||
return
|
||||
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight)
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight,
|
||||
tp_rank=self.tp_rank)
|
||||
return
|
||||
# TODO: @dsikka - move to parameter.py
|
||||
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||
@@ -1020,7 +980,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
num_heads=self.num_kv_head_replicas,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size)
|
||||
shard_size=shard_size,
|
||||
tp_rank=self.tp_rank)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
@@ -1226,6 +1187,7 @@ class RowParallelLinear(LinearBase):
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.down_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: If true, weights matrix won't be sharded through tp rank.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -1241,10 +1203,13 @@ class RowParallelLinear(LinearBase):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = (get_tensor_model_parallel_rank()
|
||||
if not disable_tp else 0)
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
@@ -1255,7 +1220,8 @@ class RowParallelLinear(LinearBase):
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
@@ -1339,10 +1305,9 @@ class RowParallelLinear(LinearBase):
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
|
||||
Reference in New Issue
Block a user