[Misc] Fused MoE Marlin support for GPTQ (#8217)
This commit is contained in:
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
|
||||
shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
|
||||
|
||||
if shard_id == "w2":
|
||||
self._load_w2(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
else:
|
||||
assert shard_id in ("w1", "w3")
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, weight_name: str,
|
||||
shard_id: str, expert_id: int) -> None:
|
||||
|
||||
# compressed-tensors represents weights on disk which are flipped
|
||||
loaded_weight = loaded_weight.t().contiguous() if (
|
||||
self.quant_method.__class__.__name__
|
||||
== "CompressedTensorsMoEMethod") else loaded_weight
|
||||
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_data = param.data[expert_id]
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# is_transposed: whether or not the parameter is transposed on disk
|
||||
# If transposed, the loaded weight will be transposed and the dim
|
||||
# to shard the loaded weight will be flipped.
|
||||
# is_transposed: if the dim to shard the weight
|
||||
# should be flipped. Required by GPTQ, compressed-tensors
|
||||
# should be whatever dimension intermediate_size is
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
loaded_weight = loaded_weight.t().contiguous()
|
||||
shard_dim = ~shard_dim
|
||||
|
||||
# Case weight_scales
|
||||
if "weight_scale" in weight_name:
|
||||
# load the weight scaling based on the quantization scheme
|
||||
# supported weight scales can be found in
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param.data[expert_id]} "
|
||||
f"vs. {loaded_weight}")
|
||||
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
return
|
||||
|
||||
# Case g_idx
|
||||
if "g_idx" in weight_name:
|
||||
self._load_g_idx(shard_dim=0,
|
||||
shard_id=shard_id,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
return
|
||||
|
||||
# Case weight scales and zero_points
|
||||
if ("scale" in weight_name or "zero" in weight_name):
|
||||
# load the weight scales and zp based on the quantization scheme
|
||||
# supported weight scales/zp can be found in
|
||||
# FusedMoeWeightScaleSupported
|
||||
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
|
||||
# specific to each case
|
||||
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
|
||||
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
|
||||
return
|
||||
|
||||
# Case weight_shape
|
||||
if "weight_shape" in weight_name:
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
return
|
||||
|
||||
# Case input scale
|
||||
if "input_scale" in weight_name:
|
||||
# Note: input_scale loading is only supported for fp8
|
||||
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param.data[expert_id]} "
|
||||
f"vs. {loaded_weight}")
|
||||
|
||||
# only required by compressed-tensors
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
else:
|
||||
param_data[expert_id] = loaded_weight
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
Reference in New Issue
Block a user