pin lora_b moe weights on cpu (#31317)
Signed-off-by: gnovack <gnovack@amazon.com>
This commit is contained in:
@@ -671,20 +671,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
self.reset_lora(index)
|
||||
self.adapter_enabled[index] = 1
|
||||
|
||||
num_experts = self.w13_lora_a_stacked[0].shape[1]
|
||||
w13_lora_a, w2_lora_a = lora_a
|
||||
w13_lora_b, w2_lora_b = lora_b
|
||||
|
||||
# (num_experts,rank,input_size)
|
||||
w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
|
||||
w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
|
||||
# (output_size,rank,num_experts)
|
||||
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], -1, num_experts)
|
||||
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], -1, num_experts)
|
||||
# (num_experts,output_size,rank)
|
||||
w13_lora_b = w13_lora_b.permute(2, 0, 1)
|
||||
w2_lora_b = w2_lora_b.permute(2, 0, 1)
|
||||
|
||||
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
|
||||
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
|
||||
|
||||
|
||||
@@ -256,61 +256,7 @@ class LoRAModelManager:
|
||||
if not module_lora:
|
||||
module.reset_lora(index)
|
||||
continue
|
||||
# Note (gnovack) - If MOE lora weights are not split into
|
||||
# num_experts chunks, we split them here
|
||||
if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
|
||||
module_lora.lora_a
|
||||
):
|
||||
# Handle PEFT file format where experts.base_layer is the
|
||||
# gate_up_proj and experts is the down_proj
|
||||
gate_up_proj_lora = self._get_lora_layer_weights(
|
||||
lora_model, module_name + ".base_layer"
|
||||
)
|
||||
down_proj_lora = module_lora
|
||||
# FIXME Edge case where LoRA is not added to gate_up_proj
|
||||
# or down_proj
|
||||
assert gate_up_proj_lora is not None
|
||||
assert down_proj_lora is not None
|
||||
if self._is_3d_moe_model:
|
||||
module_lora.lora_a = [
|
||||
gate_up_proj_lora.lora_a,
|
||||
down_proj_lora.lora_a,
|
||||
]
|
||||
module_lora.lora_b = [
|
||||
gate_up_proj_lora.lora_b,
|
||||
down_proj_lora.lora_b,
|
||||
]
|
||||
else:
|
||||
# Some 3D MoE models haven't added the `is_3d_moe_weight`
|
||||
# attribute yet, so fallback here
|
||||
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
|
||||
|
||||
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
|
||||
up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
|
||||
|
||||
gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
|
||||
num_experts, dim=-1
|
||||
)
|
||||
up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
|
||||
num_experts, dim=-1
|
||||
)
|
||||
|
||||
down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
|
||||
down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)
|
||||
|
||||
lora_a = []
|
||||
lora_b = []
|
||||
for i in range(num_experts):
|
||||
lora_a.append(gate_proj_a[i])
|
||||
lora_a.append(down_proj_a[i])
|
||||
lora_a.append(up_proj_a[i])
|
||||
|
||||
lora_b.append(gate_proj_b[i])
|
||||
lora_b.append(down_proj_b[i])
|
||||
lora_b.append(up_proj_b[i])
|
||||
|
||||
module_lora.lora_a = lora_a
|
||||
module_lora.lora_b = lora_b
|
||||
module.set_lora(
|
||||
index,
|
||||
module_lora.lora_a,
|
||||
@@ -627,6 +573,10 @@ class LoRAModelManager:
|
||||
for lora in lora_model.loras.values():
|
||||
lora.optimize()
|
||||
|
||||
for module_name, module in self.modules.items():
|
||||
if isinstance(module, FusedMoE3DWithLoRA):
|
||||
self._stack_moe_lora_weights(lora_model, module, module_name)
|
||||
|
||||
first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
|
||||
assert first_lora.lora_a is not None
|
||||
if isinstance(first_lora.lora_a, list):
|
||||
@@ -653,6 +603,91 @@ class LoRAModelManager:
|
||||
lora.lora_a = lora.lora_a.pin_memory()
|
||||
lora.lora_b = lora.lora_b.pin_memory()
|
||||
|
||||
def _stack_moe_lora_weights(
|
||||
self, lora_model: LoRAModel, module: FusedMoE3DWithLoRA, module_name: str
|
||||
):
|
||||
module_lora = self._get_lora_layer_weights(lora_model, module_name)
|
||||
|
||||
# Note (gnovack) - If MOE lora weights are not split into
|
||||
# num_experts chunks, we split them here
|
||||
if module_lora and torch.is_tensor(module_lora.lora_a):
|
||||
# Handle PEFT file format where experts.base_layer is the
|
||||
# gate_up_proj and experts is the down_proj
|
||||
gate_up_proj_lora = self._get_lora_layer_weights(
|
||||
lora_model, module_name + ".base_layer"
|
||||
)
|
||||
down_proj_lora = module_lora
|
||||
# FIXME Edge case where LoRA is not added to gate_up_proj
|
||||
# or down_proj
|
||||
assert gate_up_proj_lora is not None
|
||||
assert down_proj_lora is not None
|
||||
if self._is_3d_moe_model:
|
||||
num_experts = module.w13_lora_a_stacked[0].shape[1]
|
||||
|
||||
# (num_experts,rank,input_size)
|
||||
gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
|
||||
num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
|
||||
)
|
||||
down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
|
||||
num_experts, -1, down_proj_lora.lora_a.shape[-1]
|
||||
)
|
||||
|
||||
# (output_size,num_experts,rank)
|
||||
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
|
||||
gate_up_proj_lora.lora_b.shape[0], -1, num_experts
|
||||
)
|
||||
down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
|
||||
down_proj_lora.lora_b.shape[0], -1, num_experts
|
||||
)
|
||||
|
||||
# (num_experts,output_size,rank)
|
||||
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
|
||||
2, 0, 1
|
||||
).contiguous()
|
||||
down_proj_lora.lora_b = down_proj_lora.lora_b.permute(
|
||||
2, 0, 1
|
||||
).contiguous()
|
||||
|
||||
module_lora.lora_a = [
|
||||
gate_up_proj_lora.lora_a,
|
||||
down_proj_lora.lora_a,
|
||||
]
|
||||
module_lora.lora_b = [
|
||||
gate_up_proj_lora.lora_b,
|
||||
down_proj_lora.lora_b,
|
||||
]
|
||||
else:
|
||||
# Some 3D MoE models haven't added the `is_3d_moe_weight`
|
||||
# attribute yet, so fallback here
|
||||
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
|
||||
|
||||
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
|
||||
up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
|
||||
|
||||
gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
|
||||
num_experts, dim=-1
|
||||
)
|
||||
up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
|
||||
num_experts, dim=-1
|
||||
)
|
||||
|
||||
down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
|
||||
down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)
|
||||
|
||||
lora_a = []
|
||||
lora_b = []
|
||||
for i in range(num_experts):
|
||||
lora_a.append(gate_proj_a[i])
|
||||
lora_a.append(down_proj_a[i])
|
||||
lora_a.append(up_proj_a[i])
|
||||
|
||||
lora_b.append(gate_proj_b[i])
|
||||
lora_b.append(down_proj_b[i])
|
||||
lora_b.append(up_proj_b[i])
|
||||
|
||||
module_lora.lora_a = lora_a
|
||||
module_lora.lora_b = lora_b
|
||||
|
||||
def _get_lora_layer_weights(
|
||||
self, lora_model: LoRAModel, module_name: str
|
||||
) -> LoRALayerWeights | None:
|
||||
|
||||
Reference in New Issue
Block a user