fix lora moe sharding when rank < max_lora_rank (#31994)
Signed-off-by: gnovack <gnovack@amazon.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -95,7 +95,6 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=2,
|
||||
max_lora_rank=8,
|
||||
max_num_seqs=2,
|
||||
max_num_batched_tokens=2048,
|
||||
tensor_parallel_size=2,
|
||||
|
||||
@@ -428,9 +428,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
current_lora_rank = w13_lora_a.shape[1]
|
||||
assert current_lora_rank % self.tp_size == 0
|
||||
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
|
||||
sliced_rank = current_lora_rank // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_rank
|
||||
end_idx = (self.tp_rank + 1) * sliced_rank
|
||||
shard_size = self.w13_lora_a_stacked[0].shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
return w13_lora_a[:, start_idx:end_idx, :]
|
||||
|
||||
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
|
||||
@@ -465,11 +465,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
return w2_lora_b
|
||||
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
|
||||
# w2_lora_b shape (num_experts,output_size,rank)
|
||||
current_lora_size = w2_lora_b.shape[1]
|
||||
shard_size = self.w2_lora_b_stacked[0].shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
sliced_size = current_lora_size // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_size
|
||||
end_idx = (self.tp_rank + 1) * sliced_size
|
||||
return w2_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
|
||||
Reference in New Issue
Block a user