fix lora moe sharding when rank < max_lora_rank (#31994)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
gnovack
2026-01-08 22:43:25 -08:00
committed by GitHub
parent 707b240d7e
commit bde38c11df
2 changed files with 6 additions and 8 deletions

View File

@@ -95,7 +95,6 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,

View File

@@ -428,9 +428,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
shard_size = self.w13_lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
@@ -465,11 +465,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
shard_size = self.w2_lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def reset_lora(self, index: int):