[Qwen] Add fp8 checkpoint support for qwen3-next. (#25079)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
|
||||
self.config.hidden_size,
|
||||
gather_output=True,
|
||||
bias=False,
|
||||
return_bias=False)
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.fc')
|
||||
|
||||
self.layers = torch.nn.ModuleList(
|
||||
Qwen3NextDecoderLayer(
|
||||
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
|
||||
prefix=f'{prefix}.layers.{idx}',
|
||||
) for idx in range(self.num_mtp_layers))
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
|
||||
self.config = config
|
||||
self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
prefix, "mtp"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
|
||||
Reference in New Issue
Block a user