[CI/Build] Fix pre-commit errors (#13696)
This commit is contained in:
@@ -134,7 +134,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
return tp_size - ngroups
|
||||
|
||||
|
||||
def mamba_v2_sharded_weight_loader(
|
||||
@@ -168,12 +168,9 @@ def mamba_v2_sharded_weight_loader(
|
||||
# - compute the rank into the loaded shard.
|
||||
# - if there is replication, different TP shards will
|
||||
# take from the same rank.
|
||||
if duplicate_groups:
|
||||
# NOTE: currently we only support duplication
|
||||
# in the case where num_groups == 1
|
||||
rank = 0
|
||||
else:
|
||||
rank = tp_rank
|
||||
# NOTE: currently we only support duplication
|
||||
# in the case where num_groups == 1
|
||||
rank = 0 if duplicate_groups else tp_rank
|
||||
|
||||
# - leftmost boundary index into loaded weight.
|
||||
loaded_skip = rank * shard_size
|
||||
@@ -247,7 +244,7 @@ class MambaMixer2(CustomOp):
|
||||
assert num_heads % self.tp_size == 0, \
|
||||
"Tensor parallel world size must divide num heads."
|
||||
|
||||
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
|
||||
(
|
||||
"If tensor parallel world size does not divide num_heads, "
|
||||
|
||||
Reference in New Issue
Block a user