This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
GROUP = "group"
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@abstractmethod
|
||||
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
def _load_per_tensor_weight_scale(self, shard_id: str,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
expert_id: int):
|
||||
param_data = param.data
|
||||
# for per tensor weight quantization
|
||||
if shard_id in ("w1", "w3"):
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
idx = 0 if shard_id == "w1" else 1
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
elif shard_id == "w2":
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
|
||||
expert_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int):
|
||||
# Load grouped weight scales for group quantization
|
||||
# or model weights
|
||||
if shard_id == "w2":
|
||||
self._load_w2(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
|
||||
shard_dim: int, shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int):
|
||||
# for per channel weight quantization
|
||||
if shard_id == "w2":
|
||||
expert_data.copy_(loaded_weight)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
|
||||
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
|
||||
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||
shard_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_single_value(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, expert_id: int):
|
||||
param_data = param.data
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, weight_name: str,
|
||||
shard_id: str, expert_id: int) -> None:
|
||||
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
|
||||
WEIGHT_SCALE_SUPPORTED = [
|
||||
e.value for e in FusedMoeWeightScaleSupported
|
||||
]
|
||||
# Fetch the dim to shard the parameter/loaded weight
|
||||
# based on the shard id. This will be whatever
|
||||
# dimension intermediate_size is used.
|
||||
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||
# Special case for fp8 scales.
|
||||
if getattr(param, "is_fp8_scale", False):
|
||||
self._load_fp8_scale(param.data, loaded_weight, weight_name,
|
||||
shard_id, expert_id)
|
||||
return
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# is_transposed: whether or not the parameter is transposed on disk
|
||||
# If transposed, the loaded weight will be transposed and the dim
|
||||
# to shard the loaded weight will be flipped.
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
loaded_weight = loaded_weight.t().contiguous()
|
||||
shard_dim = ~shard_dim
|
||||
# If transposed, weight is saved as [input_dim, output_dim]
|
||||
# Otherwise, weight is saved as [output_dim, input_dim]
|
||||
# Default is not transposed/input dim is dim 1
|
||||
input_dim = getattr(param, "input_dim", 1)
|
||||
output_dim = getattr(param, "output_dim", 0)
|
||||
|
||||
# Case weight_scales
|
||||
if "weight_scale" in weight_name:
|
||||
# load the weight scaling based on the quantization scheme
|
||||
# supported weight scales can be found in
|
||||
# FusedMoeWeightScaleSupported
|
||||
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
|
||||
# specific to each case
|
||||
quant_method = getattr(param, "quant_method", None)
|
||||
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||
self._load_per_channel_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||
self._load_per_tensor_weight_scale(shard_id=shard_id,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
|
||||
return
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
if shard_id == "w2":
|
||||
shard_dim = input_dim
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
elif shard_id in ("w1", "w3"):
|
||||
shard_dim = output_dim
|
||||
shard_size = expert_data.shape[output_dim] // 2
|
||||
offset = shard_size * tp_rank
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
|
||||
|
||||
if "weight_shape" in weight_name:
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
return
|
||||
|
||||
# Case input scale
|
||||
if "input_scale" in weight_name:
|
||||
# Note: input_scale loading is only supported for fp8
|
||||
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param.data[expert_id]} "
|
||||
f"vs. {loaded_weight}")
|
||||
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
return
|
||||
|
||||
# Case model weights
|
||||
if "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
return
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
elif shard_id == "w3":
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
elif shard_id == "w2":
|
||||
expert_data.copy_(loaded_weight)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected shard_id w1,w2 or w3 but got {shard_id}")
|
||||
|
||||
@staticmethod
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user