[Model] Clean up MiniCPMV (#10751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -242,7 +242,7 @@ class FusedMoE(torch.nn.Module):
|
||||
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
|
||||
expert_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
-loaded_weight: torch.tensor,
|
||||
+loaded_weight: torch.Tensor,
|
||||
tp_rank: int):
|
||||
# Load grouped weight scales for group quantization
|
||||
# or model weights
|
||||
@@ -261,7 +261,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
|
||||
shard_dim: int, shard_id: str,
|
||||
-loaded_weight: torch.tensor,
|
||||
+loaded_weight: torch.Tensor,
|
||||
tp_rank: int):
|
||||
# for per channel weight quantization
|
||||
if shard_id == "w2":
|
||||
@@ -274,7 +274,7 @@ class FusedMoE(torch.nn.Module):
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
|
||||
-shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
|
||||
+shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
@@ -292,7 +292,7 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
|
||||
-shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
|
||||
+shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
@@ -311,7 +311,7 @@ class FusedMoE(torch.nn.Module):
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
|
||||
-shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
|
||||
+shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
|
||||
|
||||
if shard_id == "w2":
|
||||
self._load_w2(shard_id=shard_id,
|
||||
|
||||
Reference in New Issue
Block a user