Revert "[Kernel] Expand MoE weight loading + Add Fused Marlin MoE Kernel (#7527)" (#7764)

This commit is contained in:
Michael Goin
2024-08-21 23:42:14 -04:00
committed by GitHub
parent cde9183b40
commit aae74ef95c
15 changed files with 84 additions and 2374 deletions

View File

@@ -1,5 +1,4 @@
from abc import abstractmethod
from enum import Enum
from typing import List, Optional, Tuple
import torch
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
# Special case for fp8 scales.
if getattr(param, "is_fp8_scale", False):
self._load_fp8_scale(param.data, loaded_weight, weight_name,
shard_id, expert_id)
return
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
# is_transposed: whether or not the parameter is transposed on disk
# If transposed, the loaded weight will be transposed and the dim
# to shard the loaded weight will be flipped.
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
loaded_weight = loaded_weight.t().contiguous()
shard_dim = ~shard_dim
# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
# Default is not transposed/input dim is dim 1
input_dim = getattr(param, "input_dim", 1)
output_dim = getattr(param, "output_dim", 0)
# Case weight_scales
if "weight_scale" in weight_name:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
if shard_id == "w2":
shard_dim = input_dim
shard_size = expert_data.shape[shard_dim]
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif shard_id in ("w1", "w3"):
shard_dim = output_dim
shard_size = expert_data.shape[output_dim] // 2
offset = shard_size * tp_rank
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
if "weight_shape" in weight_name:
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case input scale
if "input_scale" in weight_name:
# Note: input_scale loading is only supported for fp8
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
expert_data.copy_(loaded_weight)
# w3, up_proj: Load into second logical weight of w13.
elif shard_id == "w3":
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
# w2, down_proj: Load into only logical weight of w2.
elif shard_id == "w2":
expert_data.copy_(loaded_weight)
else:
raise ValueError(
f"Expected shard_id w1,w2 or w3 but got {shard_id}")
@staticmethod
def select_experts(hidden_states: torch.Tensor,