[Core] Allow disabling TP sharding for parallel Linear layer (#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Isotr0py
Date: 2025-09-06 13:53:58 +08:00
Committed by: GitHub
Parent: 6432739ef1
Commit: 53b19ccdd5
7 changed files with 203 additions and 280 deletions
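As a rough illustration of the new API, the sketch below constructs a column- and a row-parallel layer with and without sharding. It assumes the vllm.model_executor.layers.linear import path shown in the diff, hypothetical 1024/4096 sizes, and an already-initialized vLLM distributed environment; it is a usage sketch, not part of the commit.

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

# Default behaviour: the 4096-wide output dimension is split across TP ranks.
up_proj = ColumnParallelLinear(1024, 4096, bias=False, prefix="up_proj")

# With disable_tp=True the layer reports tp_rank=0 / tp_size=1 and keeps the
# full 1024x4096 weight on every rank, regardless of the TP world size.
up_proj_unsharded = ColumnParallelLinear(1024, 4096, bias=False,
                                         prefix="up_proj",
                                         disable_tp=True)

# The same keyword is threaded through RowParallelLinear (and the merged/QKV
# variants), where it also makes the input split in forward() a no-op.
down_proj = RowParallelLinear(4096, 1024, bias=False, prefix="down_proj",
                              disable_tp=True)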


@@ -223,6 +223,7 @@ class LinearBase(CustomOp):
quant_config: Quantization configuration.
prefix: Prefix for parameter names.
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, tensor parallelism will be disabled for this layer.
"""
def __init__(
@@ -235,6 +236,7 @@ class LinearBase(CustomOp):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
super().__init__()
@@ -254,6 +256,17 @@ class LinearBase(CustomOp):
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
def __post_init__(self):
for param in self.parameters():
if isinstance(param, BasevLLMParameter):
param.tp_rank = self.tp_rank
param.tp_size = self.tp_size
@CustomOp.register("replicated_linear")
@@ -270,6 +283,7 @@ class ReplicatedLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Has no effect for replicated linear layers.
"""
def __init__(
@@ -283,26 +297,21 @@ class ReplicatedLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
# All linear layers support quant methods.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
self.output_partition_sizes,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
@@ -358,74 +367,6 @@ class ReplicatedLinear(LinearBase):
return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
param.data[shard_offset:shard_offset + shard_size] = loaded_weight
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
@@ -448,7 +389,9 @@ class ColumnParallelLinear(LinearBase):
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across TP ranks.
"""
def __init__(
@@ -464,9 +407,13 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
@@ -483,7 +430,8 @@ class ColumnParallelLinear(LinearBase):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
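To make the fallback above concrete, a small worked example with a hypothetical 1024-to-4096 layer; divide here is a local stand-in for the vLLM helper of the same name.

def divide(numerator: int, denominator: int) -> int:
    # Stand-in for vLLM's divide(): integer division that requires exact divisibility.
    assert numerator % denominator == 0, "output_size must be divisible by tp_size"
    return numerator // denominator

# tp_size=4: each rank keeps a 1024x1024 shard of the 1024x4096 weight.
print(divide(4096, 4))  # 1024
# disable_tp=True forces tp_size=1: every rank keeps the full 1024x4096 weight.
print(divide(4096, 1))  # 4096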
@@ -512,8 +460,6 @@ class ColumnParallelLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.tp_rank = get_tensor_model_parallel_rank()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
@@ -554,7 +500,8 @@ class ColumnParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
@@ -570,7 +517,7 @@ class ColumnParallelLinear(LinearBase):
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
@@ -584,7 +531,7 @@ class ColumnParallelLinear(LinearBase):
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", tp_size={self.tp_size}"
s += f", gather_output={self.gather_output}"
return s
@@ -611,6 +558,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, none of the weight matrices are sharded; the layer
    behaves like a replicated merged linear layer.
"""
def __init__(
@@ -625,10 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
@@ -640,7 +592,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def weight_loader(self,
param: Parameter,
@@ -832,8 +785,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
@@ -845,17 +796,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
block_n) // self.tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
block_n // self.tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
shard_offset = sum(
self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
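A worked example of the shard arithmetic above, assuming hypothetical output_sizes=[4096, 4096] (e.g. a merged gate/up projection) and tp_size=2; with disable_tp the divisor becomes 1 and the offsets address the full, unsharded matrices.

output_sizes = [4096, 4096]

def merged_shard(loaded_shard_id: int, tp_size: int) -> tuple[int, int]:
    # Mirrors the non-quantized branch of weight_loader above.
    shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
    shard_size = output_sizes[loaded_shard_id] // tp_size
    return shard_offset, shard_size

print(merged_shard(1, tp_size=2))  # (2048, 2048): this rank's slice of the second matrix
print(merged_shard(1, tp_size=1))  # (4096, 4096): full second matrix when TP is disabled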
class QKVParallelLinear(ColumnParallelLinear):
@@ -883,6 +836,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across TP ranks.
"""
def __init__(
@@ -898,6 +852,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
@@ -906,7 +861,8 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
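The head bookkeeping above determines how many query and KV heads each rank owns. A worked sketch with hypothetical GQA sizes (32 query heads, 8 KV heads); the else-branch is not visible in this hunk and is reproduced here as an assumption from the surrounding logic.

def heads_per_rank(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    num_heads = total_num_heads // tp_size
    if tp_size >= total_num_kv_heads:
        # More ranks than KV heads: each rank keeps one KV head (replicated).
        num_kv_heads = 1
    else:
        # Assumed from the surrounding code: KV heads are split evenly across ranks.
        num_kv_heads = total_num_kv_heads // tp_size
    return num_heads, num_kv_heads

print(heads_per_rank(32, 8, tp_size=4))   # (8, 2)
print(heads_per_rank(32, 8, tp_size=16))  # (2, 1)
print(heads_per_rank(32, 8, tp_size=1))   # (32, 8) -- the disable_tp=True case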
@@ -932,7 +888,8 @@ class QKVParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@@ -993,10 +950,13 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
param.load_qkv_weight(loaded_weight=loaded_weight,
shard_id=0,
tp_rank=self.tp_rank)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
param.load_qkv_weight(loaded_weight=loaded_weight,
tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -1020,7 +980,8 @@ class QKVParallelLinear(ColumnParallelLinear):
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
def weight_loader(self,
param: Parameter,
@@ -1226,6 +1187,7 @@ class RowParallelLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across TP ranks.
"""
def __init__(
@@ -1241,10 +1203,13 @@ class RowParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
@@ -1255,7 +1220,8 @@ class RowParallelLinear(LinearBase):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -1339,10 +1305,9 @@ class RowParallelLinear(LinearBase):
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
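When disable_tp forces tp_size=1 and tp_rank=0, the split above degenerates to a no-op: chunking into a single partition and taking index 0 returns the whole input. A minimal sketch, using torch.chunk as a stand-in for split_tensor_along_last_dim.

import torch

x = torch.randn(2, 8, 1024)  # (batch, seq, hidden), hypothetical shapes

# tp_size=4, tp_rank=2: this rank operates on its 256-wide slice of the input.
shard = torch.chunk(x, chunks=4, dim=-1)[2].contiguous()
print(shard.shape)  # torch.Size([2, 8, 256])

# disable_tp=True -> tp_size=1, tp_rank=0: the "split" is just the full tensor.
full = torch.chunk(x, chunks=1, dim=-1)[0].contiguous()
print(full.shape)  # torch.Size([2, 8, 1024])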