[bugfix][quantization] fix quark qwen3 kv_cache quantization (#30308)
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
@@ -403,6 +403,7 @@ class Qwen3MoeModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.config = config
+        self.quant_config = quant_config
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -505,6 +506,19 @@ class Qwen3MoeModel(nn.Module):
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                assert loaded_weight.numel() == 1, (
+                    f"KV scale numel {loaded_weight.numel()} != 1"
+                )
+                loaded_weight = loaded_weight.squeeze()
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
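For context, the new branch relies on `get_cache_scale` remapping a checkpoint tensor name to the name of the model's kv-cache scale parameter; when it returns a name, the scalar is loaded into that parameter and the usual stacked/expert mapping paths are skipped via `continue`. Below is a minimal sketch of that contract. The `.output_scale` checkpoint names, the exact remapping rule, and the mock `get_cache_scale`/`default_weight_loader` are illustrative assumptions, not quark's actual implementation.

```python
import torch


def get_cache_scale(name: str) -> str | None:
    """Mock of quant_config.get_cache_scale: map a checkpoint tensor
    name to the kv-cache scale param name, or None if it is not a
    kv-cache scale (assumed naming scheme for illustration)."""
    if name.endswith(".k_proj.output_scale"):
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".v_proj.output_scale"):
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None  # fall through to normal weight loading


def default_weight_loader(param: torch.nn.Parameter, w: torch.Tensor) -> None:
    """Simplified stand-in for vLLM's default_weight_loader."""
    param.data.copy_(w)


# A kv-cache scale entry gets routed to the attention layer's k_scale:
params_dict = {
    "layers.0.self_attn.attn.k_scale": torch.nn.Parameter(torch.zeros(())),
}
name = "layers.0.self_attn.k_proj.output_scale"
loaded_weight = torch.tensor([0.02])  # scale is a single element

if (scale_name := get_cache_scale(name)) is not None:
    param = params_dict[scale_name]
    assert loaded_weight.numel() == 1  # scales must be scalar
    default_weight_loader(param, loaded_weight.squeeze())
```

In the real `load_weights`, `params_dict` comes from the model's named parameters, and storing `self.quant_config` in the first hunk is what makes the `get_cache_scale` lookup possible inside the loop.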