[Qwen] Add fp8 checkpoint support for qwen3-next. (#25079)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
Tao He
2025-09-18 16:16:04 +08:00
committed by GitHub
parent 350c94deb3
commit ef7eefe17a
2 changed files with 22 additions and 21 deletions

View File

@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False)
return_bias=False,
quant_config=quant_config,
prefix=f'{prefix}.fc')
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
prefix=f'{prefix}.layers.{idx}',
) for idx in range(self.num_mtp_layers))
self.make_empty_intermediate_tensors = (
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
self.config = config
self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
prefix, "mtp"))
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
config.hidden_size,