[Misc][Quark] Upstream Quark format to VLLM (#10765)
Signed-off-by: kewang-xlnx <kewang@xilinx.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -83,7 +83,7 @@ class DbrxExperts(FusedMoE):
|
||||
|
||||
# Define custom weight loader for dbrx model
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str):
|
||||
weight_name: str, param_name: str):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
@@ -91,25 +91,37 @@ class DbrxExperts(FusedMoE):
|
||||
# DBRX uses GLU for each expert.
|
||||
# GLU has 3 linear layers: w1, v1 and w2.
|
||||
if weight_name.endswith("w1"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
)
|
||||
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
|
||||
if param_name.endswith("weight"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
)
|
||||
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
|
||||
elif param_name.endswith("weight_scale"):
|
||||
param_data[:, 0] = loaded_weight
|
||||
else:
|
||||
param_data = loaded_weight
|
||||
if weight_name.endswith("v1"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
)
|
||||
param_data[:,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[:,
|
||||
shard, :]
|
||||
if param_name.endswith("weight"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
)
|
||||
param_data[:, shard_size:2 *
|
||||
shard_size, :] = loaded_weight[:, shard, :]
|
||||
elif param_name.endswith("weight_scale"):
|
||||
param_data[:, 1] = loaded_weight
|
||||
else:
|
||||
param_data[:] = loaded_weight
|
||||
if weight_name.endswith("w2"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
).transpose(1, 2)
|
||||
param_data[:] = loaded_weight[:, :, shard]
|
||||
if param_name.endswith("weight"):
|
||||
loaded_weight = torch.reshape(
|
||||
loaded_weight,
|
||||
[-1, self.intermediate_size * self.tp_size, self.d_model],
|
||||
).transpose(1, 2)
|
||||
param_data[:] = loaded_weight[:, :, shard]
|
||||
else:
|
||||
param_data[:] = loaded_weight
|
||||
|
||||
|
||||
class DbrxMoE(nn.Module):
|
||||
@@ -430,14 +442,29 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
|
||||
expert_params_mapping = [(
|
||||
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
|
||||
"w13" if weight_name in ["w1", "v1"] else "w2",
|
||||
f"mlp.{weight_name}",
|
||||
) for weight_name in ["w1", "v1", "w2"]]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache scales for quark and
|
||||
# compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if name.endswith(("w1", "w2", "v1")):
|
||||
name = name + "_weight"
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -446,8 +473,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, weight_name)
|
||||
weight_loader(param, loaded_weight, weight_name, name)
|
||||
break
|
||||
|
||||
else:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
@@ -456,6 +484,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
Reference in New Issue
Block a user