[Core] Allow disabling TP sharding for parallel Linear layer (#23024)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -57,6 +57,8 @@ class BasevLLMParameter(Parameter):
             weight_loader = _make_synced_weight_loader(weight_loader)
 
         self._weight_loader = weight_loader
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
 
     @property
     def weight_loader(self):
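The two added lines are the heart of the change: every loader below can now read the cached self.tp_rank instead of calling get_tensor_model_parallel_rank() itself, which gives a layer a single place to override when TP sharding should be disabled. A minimal sketch of that idea, assuming a hypothetical helper name (the actual hook this PR adds lives in the Linear layer code, which is not part of this diff):

# Illustrative sketch only: shows why caching tp_rank/tp_size on the
# parameter lets a layer opt out of sharding. `disable_tp_sharding`
# is a hypothetical name, not part of this diff.
def disable_tp_sharding(param):
    # With tp_rank = 0 and tp_size = 1, every loader that computes
    # `tp_rank * shard_size` reads from offset 0, i.e. the tensor is
    # loaded unsharded on every rank.
    param.tp_rank = 0
    param.tp_size = 1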
@@ -116,10 +118,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         return self._output_dim
 
     def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.output_dim]
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
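For reference, the narrow arithmetic above selects each rank's slice of the full checkpoint tensor, starting at tp_rank * shard_size along the output dimension. A standalone PyTorch illustration with toy shapes, independent of vLLM:

import torch

# Full checkpoint weight: 8 output rows, 4 input columns.
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)

tp_size, output_dim = 2, 0
shard_size = full.shape[output_dim] // tp_size  # 4 rows per rank

# Each rank keeps rows [tp_rank * shard_size, (tp_rank + 1) * shard_size).
shards = [
    full.narrow(output_dim, tp_rank * shard_size, shard_size)
    for tp_rank in range(tp_size)
]
assert torch.equal(torch.cat(shards, dim=output_dim), full)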
@@ -127,6 +129,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
 
         # TODO: move these to PackedColumnParameter and PackedvLLMParameter
         if isinstance(
                 self,
@@ -137,11 +140,11 @@ class _ColumnvLLMParameter(BasevLLMParameter):
 
         param_data = self.data
 
-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
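This merged-column loader juggles two different offsets: shard_offset locates the destination region inside the fused parameter (for example the gate half of a gate_up weight), while self.tp_rank * shard_size locates the source slice inside the full checkpoint tensor. A toy illustration with assumed shapes:

import torch

# Toy shapes: the fused parameter on one rank holds two 4-row
# sub-weights stacked along the output dimension (8 rows total).
param_data = torch.zeros(8, 4)
loaded_gate = torch.ones(8, 4)  # full, unsharded "gate" checkpoint tensor

tp_rank, shard_size, shard_offset = 1, 4, 0  # gate occupies rows 0..3

# Destination: the gate region of this rank's fused parameter.
dst = param_data.narrow(0, shard_offset, shard_size)
# Source: this rank's slice of the full checkpoint tensor.
src = loaded_gate.narrow(0, tp_rank * shard_size, shard_size)
dst.copy_(src)  # narrow() returns a view, so this writes into param_data
assert param_data[:4].eq(1).all() and param_data[4:].eq(0).all()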
@@ -161,8 +164,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
                 shard_offset=shard_offset, shard_size=shard_size)
 
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+        shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
+                    num_heads)
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
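The rewritten shard_id arithmetic reflects KV-head replication: on our reading of this loader, num_heads acts as the number of ranks that share one k/v shard, so every rank gets its own q slice (index tp_rank) while groups of num_heads consecutive ranks map to the same k/v slice via integer division. Worked numbers, illustrative only:

# Illustrative: 8 TP ranks, replication factor num_heads = 4
# (e.g. 2 KV heads shared across 8 ranks).
num_heads = 4
for tp_rank in range(8):
    q_shard = tp_rank                # ranks 0..7 -> q shards 0..7
    kv_shard = tp_rank // num_heads  # ranks 0-3 -> 0, ranks 4-7 -> 1
    print(f"rank {tp_rank}: q shard {q_shard}, kv shard {kv_shard}")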
@@ -189,10 +192,10 @@ class RowvLLMParameter(BasevLLMParameter):
         return self._input_dim
 
     def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.input_dim]
         loaded_weight = loaded_weight.narrow(self.input_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
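Row-parallel loading mirrors the column-parallel case but slices along the input dimension, which is why a column-parallel layer's sharded output can feed a row-parallel layer without reshuffling. A toy check, with assumed shapes, that the per-rank shards tile the full weight:

import torch

# Weight stored as [output_size, input_size]; row-parallel shards dim 1.
full = torch.arange(32, dtype=torch.float32).reshape(4, 8)

tp_size, input_dim = 2, 1
shard_size = full.shape[input_dim] // tp_size  # 4 columns per rank

shards = [
    full.narrow(input_dim, tp_rank * shard_size, shard_size)
    for tp_rank in range(tp_size)
]
assert torch.equal(torch.cat(shards, dim=input_dim), full)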
@@ -414,9 +417,6 @@ class SharedWeightParameter(BasevLLMParameter):
             "weight_loader": self._fake_weight_loader
         }
 
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-
         if self.tp_size > 1:
             raise NotImplementedError(f"{self.__class__.__name__} does not "
                                       "currently support tensor parallelism")
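The three deleted lines are the counterpart of the addition in BasevLLMParameter.__init__: the subclass no longer sets tp_rank/tp_size itself, and its guard reads the attribute initialized by the base class. The pattern, reduced to a sketch with stand-in values (not vLLM code):

class Base:
    def __init__(self):
        self.tp_size = 2  # stand-in for get_tensor_model_parallel_world_size()

class Shared(Base):
    def __init__(self):
        super().__init__()  # tp_size is set here, once, for all subclasses
        if self.tp_size > 1:
            raise NotImplementedError(f"{self.__class__.__name__} does not "
                                      "currently support tensor parallelism")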