[Model] Clean up MiniCPMV (#10751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-29 12:47:06 +08:00
committed by GitHub
parent c83919c7a6
commit fa6ecb9aa7
7 changed files with 149 additions and 215 deletions

View File

@@ -242,7 +242,7 @@ class FusedMoE(torch.nn.Module):
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
-loaded_weight: torch.tensor,
+loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
@@ -261,7 +261,7 @@ class FusedMoE(torch.nn.Module):
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
-loaded_weight: torch.tensor,
+loaded_weight: torch.Tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
@@ -274,7 +274,7 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
-shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
+shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
@@ -292,7 +292,7 @@ class FusedMoE(torch.nn.Module):
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
+shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
@@ -311,7 +311,7 @@ class FusedMoE(torch.nn.Module):
param_data[expert_id] = loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
-shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
+shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
if shard_id == "w2":
self._load_w2(shard_id=shard_id,