[Fix] Align fused moe lora_b shape with peft (#31534)

Signed-off-by: bk-201 <joy25810@foxmail.com>
Author: B-201
Date: 2025-12-31 09:44:59 +08:00
Committed by: GitHub
Parent: e1ee11b2a5
Commit: ecd49ce7e6

3 changed files with 9 additions and 9 deletions

@@ -392,7 +392,7 @@ th {
 | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ |
 | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ |
 | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ |
-| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ |
+| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | ✅︎ | ✅︎ |
 | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ |
 | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ |
 | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ |
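
The added checkmark advertises LoRA support for `GptOssForCausalLM`. A minimal offline-inference sketch for exercising it, assuming a locally available adapter (the `./gpt-oss-sql-lora` path and `sql_lora` name below are placeholders):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Load the base model with LoRA support enabled.
llm = LLM(model="openai/gpt-oss-20b", enable_lora=True)

# Attach the adapter per request; the adapter path is a placeholder.
outputs = llm.generate(
    ["Return the average number of working horses on farms with more than 5000 total horses."],
    SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql_lora", 1, "./gpt-oss-sql-lora"),
)
print(outputs[0].outputs[0].text)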

@@ -34,9 +34,9 @@ The Competition_ID of competition_record is the foreign key of Competition_ID of
 ###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
 EXPECTED_LORA_OUTPUT = [
-    "SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
-    "SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
-    "SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
+    "SELECT avg(Working_Horses) FROM farm WHERE Total_Horses > 5000",
+    "SELECT max(Cows) , min(Cows) FROM farm",
+    "SELECT max(Cows) , min(Cows) FROM farm",
 ]

@@ -679,12 +679,12 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
         # (num_experts,rank,input_size)
         w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
         w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
-        # (output_size,num_experts,rank)
-        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
-        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
+        # (output_size,rank,num_experts)
+        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], -1, num_experts)
+        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], -1, num_experts)
         # (num_experts,output_size,rank)
-        w13_lora_b = w13_lora_b.permute(1, 0, 2)
-        w2_lora_b = w2_lora_b.permute(1, 0, 2)
+        w13_lora_b = w13_lora_b.permute(2, 0, 1)
+        w2_lora_b = w2_lora_b.permute(2, 0, 1)
         sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
         sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
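
For reference (not part of the commit), a standalone sketch of why the new reshape/permute pair is needed, under the assumption that the fused `lora_b` weight arrives flattened as `(output_size, rank * num_experts)` with the expert index varying fastest along the last dimension, which is what the new `(output_size, rank, num_experts)` reshape implies:

import torch

num_experts, rank, output_size = 4, 8, 16

# Per-expert lora_b stacked as (num_experts, output_size, rank), then flattened
# with the expert index innermost (assumption stated above).
per_expert = torch.randn(num_experts, output_size, rank)
flat = per_expert.permute(1, 2, 0).reshape(output_size, rank * num_experts)

# New path: (output_size, rank, num_experts) -> (num_experts, output_size, rank)
new = flat.reshape(output_size, -1, num_experts).permute(2, 0, 1)
assert torch.equal(new, per_expert)

# Old path: (output_size, num_experts, rank) -> permute(1, 0, 2) yields the same
# final shape but scrambles the rank and expert entries for this layout.
old = flat.reshape(output_size, num_experts, -1).permute(1, 0, 2)
assert old.shape == per_expert.shape
assert not torch.equal(old, per_expert)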