diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 9a4da0619..093737f11 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -671,20 +671,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA): self.reset_lora(index) self.adapter_enabled[index] = 1 - num_experts = self.w13_lora_a_stacked[0].shape[1] w13_lora_a, w2_lora_a = lora_a w13_lora_b, w2_lora_b = lora_b - # (num_experts,rank,input_size) - w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1]) - w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1]) - # (output_size,rank,num_experts) - w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], -1, num_experts) - w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], -1, num_experts) - # (num_experts,output_size,rank) - w13_lora_b = w13_lora_b.permute(2, 0, 1) - w2_lora_b = w2_lora_b.permute(2, 0, 1) - sliced_w13_lora_a = self._slice_w13_a(w13_lora_a) sliced_w13_lora_b = self._slice_w13_b(w13_lora_b) diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 69787f8f0..5ef1b823c 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -256,61 +256,7 @@ class LoRAModelManager: if not module_lora: module.reset_lora(index) continue - # Note (gnovack) - If MOE lora weights are not split into - # num_experts chunks, we split them here - if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor( - module_lora.lora_a - ): - # Handle PEFT file format where experts.base_layer is the - # gate_up_proj and experts is the down_proj - gate_up_proj_lora = self._get_lora_layer_weights( - lora_model, module_name + ".base_layer" - ) - down_proj_lora = module_lora - # FIXME Edge case where LoRA is not added to gate_up_proj - # or down_proj - assert gate_up_proj_lora is not None - assert down_proj_lora is not None - if self._is_3d_moe_model: - module_lora.lora_a = [ - gate_up_proj_lora.lora_a, - down_proj_lora.lora_a, - ] - module_lora.lora_b = [ - 
gate_up_proj_lora.lora_b, - down_proj_lora.lora_b, - ] - else: - # Some 3D MoE models haven't added the `is_3d_moe_weight` - # attribute yet, so fallback here - num_experts = module_lora.lora_a.shape[0] // module_lora.rank - gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) - up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) - - gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk( - num_experts, dim=-1 - ) - up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk( - num_experts, dim=-1 - ) - - down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0) - down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1) - - lora_a = [] - lora_b = [] - for i in range(num_experts): - lora_a.append(gate_proj_a[i]) - lora_a.append(down_proj_a[i]) - lora_a.append(up_proj_a[i]) - - lora_b.append(gate_proj_b[i]) - lora_b.append(down_proj_b[i]) - lora_b.append(up_proj_b[i]) - - module_lora.lora_a = lora_a - module_lora.lora_b = lora_b module.set_lora( index, module_lora.lora_a, @@ -627,6 +573,10 @@ class LoRAModelManager: for lora in lora_model.loras.values(): lora.optimize() + for module_name, module in self.modules.items(): + if isinstance(module, FusedMoE3DWithLoRA): + self._stack_moe_lora_weights(lora_model, module, module_name) + first_lora: LoRALayerWeights = next(iter(lora_model.loras.values())) assert first_lora.lora_a is not None if isinstance(first_lora.lora_a, list): @@ -653,6 +603,91 @@ class LoRAModelManager: lora.lora_a = lora.lora_a.pin_memory() lora.lora_b = lora.lora_b.pin_memory() + def _stack_moe_lora_weights( + self, lora_model: LoRAModel, module: FusedMoE3DWithLoRA, module_name: str + ): + module_lora = self._get_lora_layer_weights(lora_model, module_name) + + # Note (gnovack) - If MOE lora weights are not split into + # num_experts chunks, we split them here + if module_lora and torch.is_tensor(module_lora.lora_a): + # Handle PEFT file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + 
gate_up_proj_lora = self._get_lora_layer_weights( +                lora_model, module_name + ".base_layer" +            ) +            down_proj_lora = module_lora +            # FIXME Edge case where LoRA is not added to gate_up_proj +            # or down_proj +            assert gate_up_proj_lora is not None +            assert down_proj_lora is not None +            if self._is_3d_moe_model: +                num_experts = module.w13_lora_a_stacked[0].shape[1] + +                # (num_experts,rank,input_size) +                gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape( +                    num_experts, -1, gate_up_proj_lora.lora_a.shape[-1] +                ) +                down_proj_lora.lora_a = down_proj_lora.lora_a.reshape( +                    num_experts, -1, down_proj_lora.lora_a.shape[-1] +                ) + +                # (output_size,rank,num_experts) +                gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape( +                    gate_up_proj_lora.lora_b.shape[0], -1, num_experts +                ) +                down_proj_lora.lora_b = down_proj_lora.lora_b.reshape( +                    down_proj_lora.lora_b.shape[0], -1, num_experts +                ) + +                # (num_experts,output_size,rank) +                gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute( +                    2, 0, 1 +                ).contiguous() +                down_proj_lora.lora_b = down_proj_lora.lora_b.permute( +                    2, 0, 1 +                ).contiguous() + +                module_lora.lora_a = [ +                    gate_up_proj_lora.lora_a, +                    down_proj_lora.lora_a, +                ] +                module_lora.lora_b = [ +                    gate_up_proj_lora.lora_b, +                    down_proj_lora.lora_b, +                ] +            else: +                # Some 3D MoE models haven't added the `is_3d_moe_weight` +                # attribute yet, so fallback here +                num_experts = module_lora.lora_a.shape[0] // module_lora.rank + +                gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) +                up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) + +                gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk( +                    num_experts, dim=-1 +                ) +                up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk( +                    num_experts, dim=-1 +                ) + +                down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0) +                down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1) + +                lora_a = [] +                lora_b = [] +                for i in range(num_experts): +                    lora_a.append(gate_proj_a[i]) + 
lora_a.append(down_proj_a[i]) + lora_a.append(up_proj_a[i]) + + lora_b.append(gate_proj_b[i]) + lora_b.append(down_proj_b[i]) + lora_b.append(up_proj_b[i]) + + module_lora.lora_a = lora_a + module_lora.lora_b = lora_b + def _get_lora_layer_weights( self, lora_model: LoRAModel, module_name: str ) -> LoRALayerWeights | None: