[Fix] Align fused moe lora_b shape with peft (#31534)
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
@@ -392,7 +392,7 @@ th {
|
||||
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ |
|
||||
| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ |
|
||||
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ |
|
||||
| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | ✅︎ | ✅︎ |
|
||||
| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@@ -34,9 +34,9 @@ The Competition_ID of competition_record is the foreign key of Competition_ID of
|
||||
###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
|
||||
|
||||
EXPECTED_LORA_OUTPUT = [
|
||||
"SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
|
||||
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
|
||||
"SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
|
||||
"SELECT avg(Working_Horses) FROM farm WHERE Total_Horses > 5000",
|
||||
"SELECT max(Cows) , min(Cows) FROM farm",
|
||||
"SELECT max(Cows) , min(Cows) FROM farm",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -679,12 +679,12 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
# (num_experts,rank,input_size)
|
||||
w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
|
||||
w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
|
||||
# (output_size,num_experts,rank)
|
||||
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
|
||||
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
|
||||
# (output_size,rank,num_experts)
|
||||
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], -1, num_experts)
|
||||
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], -1, num_experts)
|
||||
# (num_experts,output_size,rank)
|
||||
w13_lora_b = w13_lora_b.permute(1, 0, 2)
|
||||
w2_lora_b = w2_lora_b.permute(1, 0, 2)
|
||||
w13_lora_b = w13_lora_b.permute(2, 0, 1)
|
||||
w2_lora_b = w2_lora_b.permute(2, 0, 1)
|
||||
|
||||
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
|
||||
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
|
||||
|
||||
Reference in New Issue
Block a user