[Core] Allow disabling TP sharding for parallel Linear layer (#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Isotr0py
Date: 2025-09-06 13:53:58 +08:00
Committed by: GitHub
Parent: 6432739ef1
Commit: 53b19ccdd5

7 changed files with 203 additions and 280 deletions

vllm/model_executor/parameter.py

@@ -57,6 +57,8 @@ class BasevLLMParameter(Parameter):
             weight_loader = _make_synced_weight_loader(weight_loader)
 
         self._weight_loader = weight_loader
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
 
     @property
     def weight_loader(self):
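
Context for the hunk above: the TP rank and world size are now resolved once in BasevLLMParameter.__init__ and cached on the parameter, so the load_*_weight methods below can read self.tp_rank / self.tp_size instead of calling the group helpers on every load. A minimal sketch of the pattern, with plain torch.distributed standing in for vLLM's get_tensor_model_parallel_* helpers:

```python
import torch
import torch.distributed as dist

class ShardedParameter(torch.nn.Parameter):
    def __new__(cls, data: torch.Tensor):
        param = super().__new__(cls, data, requires_grad=False)
        # Resolve TP state once at construction; weight loaders then read
        # param.tp_rank / param.tp_size instead of re-querying per load.
        param.tp_rank = dist.get_rank() if dist.is_initialized() else 0
        param.tp_size = dist.get_world_size() if dist.is_initialized() else 1
        return param
```
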
@@ -116,10 +118,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         return self._output_dim
 
     def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.output_dim]
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
 
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
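
load_column_parallel_weight copies this rank's slice of the full checkpoint tensor along the output dimension. A standalone illustration of the narrow() arithmetic, with made-up shapes:

```python
import torch

# Full checkpoint weight: 8 output rows, 4 input columns (made-up shapes).
full_weight = torch.arange(32.0).reshape(8, 4)
tp_size, tp_rank, output_dim = 4, 1, 0

shard_size = full_weight.shape[output_dim] // tp_size        # 2 rows per rank
shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
assert shard.shape == (2, 4)  # rank 1 owns rows 2 and 3
```
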
@@ -127,6 +129,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance(
self,
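
The PackedColumnParameter / PackedvLLMParameter branch rescales the shard geometry because several quantized values share one storage element. A hedged sketch of such an adjustment (the helper name and pack_factor are illustrative, not the exact vLLM API):

```python
def adjust_shard_indexes_for_packing_sketch(shard_size: int,
                                            shard_offset: int,
                                            pack_factor: int) -> tuple[int, int]:
    # E.g. eight int4 values packed per int32 element => pack_factor == 8,
    # so sizes/offsets along the packed dim shrink by that factor.
    return shard_size // pack_factor, shard_offset // pack_factor

assert adjust_shard_indexes_for_packing_sketch(64, 128, 8) == (8, 16)
```
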
@@ -137,11 +140,11 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
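
load_merged_column_weight writes each fused sub-weight (e.g. gate and up projections sharing one parameter) to its own shard_offset while still reading this rank's slice of the checkpoint tensor. A toy version, with made-up shapes:

```python
import torch

tp_size, tp_rank = 4, 1
param = torch.zeros(4, 4)      # local fused parameter: gate shard + up shard

def load_merged(loaded: torch.Tensor, shard_offset: int, shard_size: int):
    dst = param.narrow(0, shard_offset, shard_size)           # slot in fused param
    src = loaded.narrow(0, tp_rank * shard_size, shard_size)  # this rank's slice
    dst.copy_(src)

gate_full = torch.randn(8, 4)  # full (unsharded) checkpoint tensors
up_full = torch.randn(8, 4)
load_merged(gate_full, shard_offset=0, shard_size=2)
load_merged(up_full, shard_offset=2, shard_size=2)
```
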
@@ -161,8 +164,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
                 shard_offset=shard_offset, shard_size=shard_size)
 
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+        shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
+                    num_heads)
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
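
The shard_id arithmetic handles grouped-query attention: every rank owns its own q shard, while several ranks replicate one k/v shard. Assuming the num_heads kwarg counts how many ranks share a kv shard (an assumption about its meaning here), the mapping works out to:

```python
# Assumption: num_heads is the number of TP ranks sharing one kv shard.
tp_size, num_heads = 4, 2
for tp_rank in range(tp_size):
    q_shard = tp_rank                # ranks 0..3 -> q shards 0..3
    kv_shard = tp_rank // num_heads  # ranks 0,1 -> kv shard 0; ranks 2,3 -> 1
    print(tp_rank, q_shard, kv_shard)
```
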
@@ -189,10 +192,10 @@ class RowvLLMParameter(BasevLLMParameter):
         return self._input_dim
 
     def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.input_dim]
         loaded_weight = loaded_weight.narrow(self.input_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
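
load_row_parallel_weight mirrors the column case but slices along the input dimension, and the reshape(1) above promotes 0-d tensors (e.g. per-tensor quantization scales) so the shape check and copy still apply. A standalone example:

```python
import torch

tp_size, tp_rank, input_dim = 4, 2, 1
full_weight = torch.randn(4, 8)          # input dim sharded across ranks
shard_size = full_weight.shape[input_dim] // tp_size
shard = full_weight.narrow(input_dim, tp_rank * shard_size, shard_size)
assert shard.shape == (4, 2)             # rank 2 owns columns 4 and 5

scale = torch.tensor(0.5)                # 0-d scalar from a checkpoint
if len(scale.shape) == 0:
    scale = scale.reshape(1)             # now shape (1,), copyable
```
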
@@ -414,9 +417,6 @@ class SharedWeightParameter(BasevLLMParameter):
"weight_loader": self._fake_weight_loader
}
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > 1:
raise NotImplementedError(f"{self.__class__.__name__} does not "
"currently support tensor parallelism")