add support for --fully-sharded-loras in fused_moe (#28761)
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
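In fully-sharded mode the MoE LoRA weights are themselves split across tensor-parallel ranks: the LoRA A matrices of w1/w3 are sliced along the rank dimension and the LoRA B matrix of w2 along the hidden_size dimension (following S-LoRA, per the comment in the layer code below). The fused op then rejoins the partial shrink results with an all-gather or all-reduce, and the w2 expand writes its slice of the output at a per-rank offset.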
@@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from vllm.distributed.utils import divide
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -205,6 +206,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     shrink_config, ## pass the shrink config
     expand_config, ## pass the expand config
     self.adapter_enabled,
+    fully_sharded=self.fully_sharded,
 )

 result = func(*args, **kwargs)
@@ -250,7 +252,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
 sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
 intermediate_cache2 = moe_state_dict["intermediate_cache2"]
 intermediate_cache3 = args[0]
-max_lora_rank = self.w1_lora_a_stacked.shape[-2]
+max_lora_rank = self.w2_lora_a_stacked.shape[-2]
+
+shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
+
 self.punica_wrapper.add_lora_fused_moe(
     intermediate_cache3,
     intermediate_cache2,
@@ -266,6 +271,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     expand_config, ## pass the expand config
     self.adapter_enabled,
     True,
+    fully_sharded=self.fully_sharded,
+    offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
 )

 result = func(*args, **kwargs)
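A minimal sketch of the offset arithmetic passed above, assuming example values for hidden_size and the TP degree; the local divide helper only mirrors the exact-division contract of vllm.distributed.utils.divide.

# Illustrative only: per-rank column offset for the w2 LoRA expand output.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0  # exact division, like vllm's divide()
    return numerator // denominator

hidden_size, tp_size = 4096, 4          # assumed sizes
shard_size_w2 = divide(hidden_size, tp_size)

for tp_rank in range(tp_size):
    offset = shard_size_w2 * tp_rank    # start of this rank's output columns
    print(f"rank {tp_rank}: columns [{offset}, {offset + shard_size_w2})")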
@@ -294,6 +301,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     model_config: PretrainedConfig | None = None,
 ) -> None:
     """Initializes lora matrices."""
+    self.fully_sharded = lora_config.fully_sharded_loras

     self.adapter_enabled = torch.tensor(
         [0] * (max_loras + 1), dtype=torch.int, device=self.device
@@ -303,7 +311,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
 (
     max_loras,
     self.base_layer.local_num_experts,
-    lora_config.max_lora_rank,
+    lora_config.max_lora_rank
+    if not self.fully_sharded
+    else divide(lora_config.max_lora_rank, self.tp_size),
     self.base_layer.hidden_size,
 ),
 dtype=lora_config.lora_dtype,
@@ -334,7 +344,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
 (
     max_loras,
     self.base_layer.local_num_experts,
-    self.base_layer.hidden_size,
+    self.base_layer.hidden_size
+    if not self.fully_sharded
+    else divide(self.base_layer.hidden_size, self.tp_size),
     lora_config.max_lora_rank,
 ),
 dtype=lora_config.lora_dtype,
@@ -345,7 +357,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
 (
     max_loras,
     self.base_layer.local_num_experts,
-    lora_config.max_lora_rank,
+    lora_config.max_lora_rank
+    if not self.fully_sharded
+    else divide(lora_config.max_lora_rank, self.tp_size),
     self.base_layer.hidden_size,
 ),
 dtype=lora_config.lora_dtype,
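Taken together, the allocations above shrink one dimension of each stacked buffer by the TP degree when fully_sharded is set. A rough sketch with assumed sizes (not the actual allocation code):

# Illustrative only: per-rank shapes of the stacked LoRA buffers.
def divide(a: int, b: int) -> int:
    assert a % b == 0
    return a // b

max_loras, local_num_experts = 4, 8
max_lora_rank, hidden_size, tp_size = 16, 4096, 4
fully_sharded = True

# w1/w3 LoRA A: the rank dimension is split across TP ranks.
a_rank = max_lora_rank if not fully_sharded else divide(max_lora_rank, tp_size)
w13_lora_a_shape = (max_loras, local_num_experts, a_rank, hidden_size)

# w2 LoRA B: the hidden_size (output) dimension is split across TP ranks.
b_rows = hidden_size if not fully_sharded else divide(hidden_size, tp_size)
w2_lora_b_shape = (max_loras, local_num_experts, b_rows, max_lora_rank)

print(w13_lora_a_shape)  # (4, 8, 4, 4096)
print(w2_lora_b_shape)   # (4, 8, 1024, 16)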
@@ -419,6 +433,20 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
 w3_lora_b = w3_lora_b[start_idx:end_idx, :]
 w2_lora_a = w2_lora_a[:, start_idx:end_idx]

+if self.fully_sharded:
+    # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
+    # and W2 B along the hidden_size dim.
+    w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+    w13_start_idx = self.tp_rank * w13_shard_size
+    w13_end_idx = (self.tp_rank + 1) * w13_shard_size
+    w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
+    w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
+
+    w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
+    w2_start_idx = self.tp_rank * w2_shard_size
+    w2_end_idx = (self.tp_rank + 1) * w2_shard_size
+    w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
+
 self.w1_lora_a_stacked[
     index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
 ].copy_(w1_lora_a, non_blocking=True)
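The new branch above keeps only this rank's rows of the w1/w3 A matrices (rank dimension) and of the w2 B matrix (hidden_size dimension) before copying into the stacked buffers. A standalone sketch of the same slicing with made-up tensor sizes:

import torch

# Illustrative only: per-rank slicing of one expert's LoRA weights.
rank, hidden_size, tp_size, tp_rank = 16, 64, 4, 1

w1_lora_a = torch.randn(rank, hidden_size)   # (r, H)
w3_lora_a = torch.randn(rank, hidden_size)   # (r, H)
w2_lora_b = torch.randn(hidden_size, rank)   # (H, r)

# w1/w3 A: one contiguous block of rows along the rank dim per TP rank.
w13_shard = rank // tp_size
w1_shard = w1_lora_a[tp_rank * w13_shard:(tp_rank + 1) * w13_shard, :]
w3_shard = w3_lora_a[tp_rank * w13_shard:(tp_rank + 1) * w13_shard, :]

# w2 B: one contiguous block of rows along the hidden_size dim per TP rank.
w2_shard_size = hidden_size // tp_size
w2_shard = w2_lora_b[tp_rank * w2_shard_size:(tp_rank + 1) * w2_shard_size, :]

print(w1_shard.shape, w3_shard.shape, w2_shard.shape)
# torch.Size([4, 64]) torch.Size([4, 64]) torch.Size([16, 16])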
@@ -3,6 +3,10 @@

 import torch

+from vllm.distributed import (
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

@@ -311,6 +315,7 @@ def _fused_moe_lora_expand(
     num_stages: int,
     split_k: int,
     mul_routed_weight: bool = False,
+    offset: int = 0,
 ) -> None:
     b_ptr = _get_ptr(lora_b_stacked, device)
     K = max_lora_rank
@@ -380,7 +385,7 @@ def _fused_moe_lora_expand(
         **expand_config,
     )
     for i in range(num_slices):
-        output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
+        output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]


 @torch.inference_mode()
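With offset applied, each rank's expand result lands in its own column block of the full-width output and is zero elsewhere, so the layer's existing tensor-parallel reduction can recombine the blocks. A toy, single-process illustration with plain tensors standing in for the distributed reduction:

import torch

# Illustrative only: per-rank column writes at an offset sum back to the
# dense result when the rank outputs are added together.
tokens, width, tp_size = 3, 8, 4
n = width // tp_size                       # per-rank slice width
dense = torch.randn(tokens, width)         # pretend: the unsharded result

rank_outputs = []
for tp_rank in range(tp_size):             # simulate the TP ranks serially
    out = torch.zeros(tokens, width)
    offset = n * tp_rank
    out[:, offset:offset + n] += dense[:, offset:offset + n]
    rank_outputs.append(out)

combined = torch.stack(rank_outputs).sum(dim=0)   # stand-in for all-reduce
assert torch.allclose(combined, dense)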
@@ -416,6 +421,8 @@ def _fused_moe_lora(
     expand_num_stages: int,
     expand_split_k: int,
     mul_routed_weight: bool = False,
+    fully_sharded: bool = False,
+    offset: int = 0,
 ) -> None:
     assert len(lora_a_stacked) == len(lora_b_stacked) > 0
     assert (
@@ -430,7 +437,6 @@ def _fused_moe_lora(
         == expert_ids.shape[0]
         == num_tokens_post_padded.shape[0]
     )
-    assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
     device = qcurr_hidden_states.device
@@ -480,6 +486,19 @@ def _fused_moe_lora(
         mul_routed_weight,
     )

+    if fully_sharded:
+        if max_lora_rank == w1_lora_b_stacked.shape[-1]:
+            a_intermediate_cache1 = tensor_model_parallel_all_reduce(
+                a_intermediate_cache1
+            )
+        else:
+            a_intermediate_cache1 = tensor_model_parallel_all_gather(
+                a_intermediate_cache1
+            )
+
+        # reset max_lora_rank to the full rank after allgather
+        max_lora_rank = a_intermediate_cache1.shape[-1]
+
     _fused_moe_lora_expand(
         output,
         a_intermediate_cache1,
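Which collective is needed depends on how the shrink result was produced: if each rank computed a distinct slice of the rank dimension (w1/w3, A sliced along rank), the slices are concatenated; if each rank computed a full-rank partial sum over its input shard (w2, A sliced along its input dimension, as in set_lora above), the partials are summed. A toy, single-process sketch with torch.cat and sum standing in for the collectives:

import torch

# Illustrative only: why the shrink output is all-gathered in one case and
# all-reduced in the other. Plain tensor ops replace the collectives.
tokens, hidden, rank, tp_size = 4, 32, 8, 4
x = torch.randn(tokens, hidden)
lora_a = torch.randn(rank, hidden)            # full LoRA A, shape (r, H)
ref = x @ lora_a.t()                          # unsharded shrink result

# Case 1 (w1/w3): A sliced along the rank dim -> each rank owns different
# output columns; concatenating the slices rebuilds the result (all-gather).
rank_shards = lora_a.chunk(tp_size, dim=0)
gathered = torch.cat([x @ a.t() for a in rank_shards], dim=-1)

# Case 2 (w2): A kept at full rank but sliced along its input dim -> each
# rank holds a partial sum; adding them rebuilds the result (all-reduce).
x_shards = x.chunk(tp_size, dim=-1)
in_shards = lora_a.chunk(tp_size, dim=-1)
reduced = sum(xs @ a.t() for xs, a in zip(x_shards, in_shards))

assert torch.allclose(gathered, ref, atol=1e-5)
assert torch.allclose(reduced, ref, atol=1e-5)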
@@ -510,6 +529,7 @@ def _fused_moe_lora(
         expand_num_stages,
         expand_split_k,
         mul_routed_weight,
+        offset,
     )

@@ -483,6 +483,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
         expand_config,
         adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
+        fully_sharded: bool = False,
+        offset: int = 0,
     ):
         """
         Performs a fused forward computation for LoRA of
@@ -375,6 +375,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         expand_config,
         adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
+        fully_sharded: bool = False,
+        offset: int = 0,
     ):
         """
         Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
@@ -408,4 +410,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             expand_config.get("NUM_STAGES", 3),
             expand_config.get("SPLIT_K", 1),
             mul_routed_weight,
+            fully_sharded,
+            offset,
         )