[Model] Support for fairseq2 Llama (#11442)

Signed-off-by: Martin Gleize <mgleize@meta.com>
Co-authored-by: mgleize user <mgleize@a100-st-p4de24xlarge-4.fair-a100.hpcaas>
Author: Martin Gleize
Date: 2025-01-19 19:40:40 +01:00
Committed by: GitHub
Parent: 81763c58a0
Commit: bbe5f9de7d
7 changed files with 197 additions and 21 deletions

vllm/model_executor/layers/linear.py

@@ -344,11 +344,13 @@ class ColumnParallelLinear(LinearBase):
             param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
 
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
 
         param_data = param.data
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow here
-        if output_dim is not None and not use_bitsandbytes_4bit:
+        if output_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -546,6 +548,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
+            is_sharded_weight = getattr(param, "is_sharded_weight", False)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow
+            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
             if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
@@ -554,9 +561,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
         # Special case for AQLM codebooks.
@@ -941,6 +946,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
+            is_sharded_weight = getattr(param, "is_sharded_weight", False)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow
+            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
             if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
@@ -964,9 +974,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            # bitsandbytes loads the weights of the specific portion
-            # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
@@ -1070,6 +1078,10 @@ class RowParallelLinear(LinearBase):
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
 
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1085,9 +1097,7 @@ class RowParallelLinear(LinearBase):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
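
All four hunks follow the same pattern: a parameter may now carry an is_sharded_weight attribute, and any weight whose checkpoint tensor is already split per tensor-parallel rank (previously only the bitsandbytes 4-bit case, now also fairseq2 checkpoints) skips the narrow() step and is copied as-is. The snippet below is a minimal standalone sketch of that contract, not code from this PR; load_shard and its arguments are hypothetical names used only for illustration.

    import torch
    from torch import nn

    # Sketch of the is_sharded_weight contract (hypothetical helper, not part of
    # the PR): without the flag, this rank's slice is cut out of the full tensor
    # with narrow(); with the flag set, the checkpoint tensor is assumed to be
    # pre-split per tensor-parallel rank and is copied verbatim.
    def load_shard(param: nn.Parameter,
                   loaded_weight: torch.Tensor,
                   tp_rank: int,
                   output_dim: int = 0) -> None:
        if not getattr(param, "is_sharded_weight", False):
            shard_size = param.data.shape[output_dim]
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 tp_rank * shard_size,
                                                 shard_size)
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)

    # Example with TP=2, rank 1: a full (8, 4) weight vs. a pre-sharded (4, 4) shard.
    param = nn.Parameter(torch.empty(4, 4))
    full_weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)
    load_shard(param, full_weight, tp_rank=1)      # narrowed to rows 4..7

    param.is_sharded_weight = True                 # e.g. a fairseq2-style TP shard
    load_shard(param, full_weight[4:], tp_rank=1)  # already sharded: copied as-is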