diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py
index c08dc9c46..750ccd7c9 100644
--- a/tests/lora/test_gptoss_tp.py
+++ b/tests/lora/test_gptoss_tp.py
@@ -95,7 +95,6 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
         max_model_len=1024,
         enable_lora=True,
         max_loras=2,
-        max_lora_rank=8,
         max_num_seqs=2,
         max_num_batched_tokens=2048,
         tensor_parallel_size=2,
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 99242806c..f0bcca915 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -428,9 +428,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         current_lora_rank = w13_lora_a.shape[1]
         assert current_lora_rank % self.tp_size == 0
         # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
-        sliced_rank = current_lora_rank // self.tp_size
-        start_idx = self.tp_rank * sliced_rank
-        end_idx = (self.tp_rank + 1) * sliced_rank
+        shard_size = self.w13_lora_a_stacked[0].shape[2]
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
         return w13_lora_a[:, start_idx:end_idx, :]
 
     def _slice_w13_b(self, w13_lora_b: torch.Tensor):
@@ -465,11 +465,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             return w2_lora_b
         # Based on S-LoRA, we slice W2 B along the hidden_size dim.
         # w2_lora_b shape (num_experts, output_size, rank)
-        current_lora_size = w2_lora_b.shape[1]
+        shard_size = self.w2_lora_b_stacked[0].shape[2]
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
 
-        sliced_size = current_lora_size // self.tp_size
-        start_idx = self.tp_rank * sliced_size
-        end_idx = (self.tp_rank + 1) * sliced_size
         return w2_lora_b[:, start_idx:end_idx, :]
 
     def reset_lora(self, index: int):
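
Note on the change (inferred from the diff, not stated in it): the stacked LoRA buffers are sized from `max_lora_rank`, so slicing an adapter's tensors by `current_lora_rank // self.tp_size` places each TP rank's slice at offsets that disagree with the buffer layout whenever the adapter's actual rank is smaller than `max_lora_rank`. Deriving the shard size from the stacked buffer itself (`self.w13_lora_a_stacked[0].shape[2]` and `self.w2_lora_b_stacked[0].shape[2]`) keeps the slices aligned with the allocated shards, which would also explain why the test no longer needs to pin `max_lora_rank=8`. A minimal standalone sketch of the two slicing schemes (illustrative shapes and names, not vLLM internals):

```python
import torch

num_experts, hidden = 2, 8
max_lora_rank, current_rank, tp_size = 16, 8, 2

# Each TP rank's stacked buffer holds max_lora_rank // tp_size rank slots.
buffer_shard = max_lora_rank // tp_size  # 8 slots per TP rank

lora_a = torch.randn(num_experts, current_rank, hidden)  # a rank-8 adapter

for tp_rank in range(tp_size):
    # Old scheme: tile the adapter's actual rank (4 rows per TP rank here),
    # which lands rows in different slots than the buffer layout expects.
    old_shard = current_rank // tp_size
    old = lora_a[:, tp_rank * old_shard:(tp_rank + 1) * old_shard, :]
    # New scheme: tile by the buffer's shard size; TP rank 0 takes rows 0:8,
    # TP rank 1 takes rows 8:16 (empty for a rank-8 adapter), so every row
    # stays in the slot it occupies in the stacked buffer.
    new = lora_a[:, tp_rank * buffer_shard:(tp_rank + 1) * buffer_shard, :]
    print(tp_rank, tuple(old.shape), tuple(new.shape))
# -> 0 (2, 4, 8) (2, 8, 8)
#    1 (2, 4, 8) (2, 0, 8)
```

Under these assumed shapes the work is no longer spread evenly across TP ranks for low-rank adapters, but every row ends up in the slot the rest of the pipeline indexes, which is the property the fix appears to restore.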