[Misc] Fused MoE Marlin support for GPTQ (#8217)

Author: Dipika Sikka
Date: 2024-09-09 23:02:52 -04:00
Committed by: GitHub
Parent: c7cb5c3335
Commit: 6cd5e5b07e
19 changed files with 912 additions and 204 deletions


@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

        if shard_id == "w2":
            self._load_w2(shard_id=shard_id,
                          shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:

        # compressed-tensors stores these weights transposed (flipped) on disk
        loaded_weight = loaded_weight.t().contiguous() if (
            self.quant_method.__class__.__name__
            == "CompressedTensorsMoEMethod") else loaded_weight

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be in ['w1','w2','w3'] but "
                             f"got {shard_id}.")
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: whether or not the parameter is transposed on disk
        # If transposed, the loaded weight will be transposed and the dim
        # to shard the loaded weight will be flipped.
        # is_transposed: whether the dim used to shard the weight should be
        # flipped. Required by GPTQ and compressed-tensors; the shard dim
        # should be whichever dimension holds intermediate_size.
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            loaded_weight = loaded_weight.t().contiguous()
            shard_dim = ~shard_dim

        # Case weight_scales
        if "weight_scale" in weight_name:
            # load the weight scaling based on the quantization scheme
            # supported weight scales can be found in

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=tp_rank)
            return

        # Case weight scales and zero_points
        if ("scale" in weight_name or "zero" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
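
For checkpoints that store the expert weights transposed (GPTQ and compressed-tensors), the loader transposes the loaded tensor and flips the shard dimension with `shard_dim = ~shard_dim`; on a 2-D tensor, ~0 is -1 and ~1 is -2, which selects the other dimension. A rough sketch of that flip, assuming sharding is a simple narrow along shard_dim (the shard helper below is illustrative, not the vLLM implementation):

# Sketch of the shard-dim flip for transposed checkpoints. For a 2-D
# weight, ~0 == -1 and ~1 == -2, so bitwise-not picks the "other" dim.
import torch

def shard(weight: torch.Tensor, shard_dim: int, is_transposed: bool,
          tp_rank: int, tp_size: int) -> torch.Tensor:
    if is_transposed:
        weight = weight.t().contiguous()
        shard_dim = ~shard_dim          # 0 -> -1, 1 -> -2
    shard_size = weight.shape[shard_dim] // tp_size
    return weight.narrow(shard_dim, tp_rank * shard_size, shard_size)

w = torch.randn(16, 4)                  # weight stored transposed on disk
out = shard(w, shard_dim=0, is_transposed=True, tp_rank=0, tp_size=2)
print(out.shape)                        # torch.Size([4, 8])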
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Case weight_shape
if "weight_shape" in weight_name:
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case input scale
if "input_scale" in weight_name:
# Note: input_scale loading is only supported for fp8
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
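
The input_scale branch only applies to fp8: w1 and w3 consume the same activations, so a single per-expert scale is stored and the two checkpoint values must agree within a small tolerance before being written via _load_single_value. A toy sketch of that check (load_input_scale is a made-up stand-in for the real method):

# Sketch of the fp8 input_scale equality check: w1 and w3 see the same
# input, so one scale per expert is kept and both loads must match.
import torch

def load_input_scale(param_data: torch.Tensor, expert_id: int,
                     loaded_scale: torch.Tensor) -> None:
    current = param_data[expert_id]
    if current != 1 and (current - loaded_scale).abs() > 1e-5:
        raise ValueError(f"input_scales of w1 and w3 must be equal, "
                         f"got {current} vs. {loaded_scale}")
    param_data[expert_id] = loaded_scale

scales = torch.ones(4)                             # one scale per expert
load_input_scale(scales, 2, torch.tensor(0.02))    # first load (w1): accepted
load_input_scale(scales, 2, torch.tensor(0.02))    # second load (w3): matches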
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        else:
            param_data[expert_id] = loaded_weight
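
The final hunk keeps the two assignment paths: in the column-parallel case (w1/w3 stacked into one merged parameter) the loaded weight lands in the expert's slice at idx, while in the row-parallel case (down_proj) the whole per-expert parameter is replaced. A small illustration of the two layouts, with made-up shapes:

# Sketch only: merged gate_up (column-parallel) vs. down_proj (row-parallel).
import torch

num_experts, inter, hidden = 2, 4, 8
gate_up = torch.zeros(num_experts, 2 * inter, hidden)   # w1 and w3 stacked
down = torch.zeros(num_experts, hidden, inter)

w1 = torch.randn(inter, hidden)
w2 = torch.randn(hidden, inter)

# Column-parallel (w1/w3): write into the expert's slice of the merged param
gate_up[0][0:inter] = w1
# Row-parallel (w2): the expert's parameter is replaced wholesale
down[0] = w2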