Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights."""
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
@@ -34,35 +35,48 @@ from transformers.models.glm4_moe import Glm4MoeConfig
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -74,19 +88,24 @@ class Glm4MoeMLP(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@@ -97,7 +116,6 @@ class Glm4MoeMLP(nn.Module):
|
||||
|
||||
|
||||
class Glm4MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
@@ -116,8 +134,10 @@ class Glm4MoE(nn.Module):
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
# NOTE In the transformers implementation, the gate isn't an nn.Linear,
|
||||
# so we cannot use ReplicatedLinear here.
|
||||
# See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
|
||||
@@ -128,7 +148,8 @@ class Glm4MoE(nn.Module):
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts, dtype=torch.float32))
|
||||
torch.empty(config.n_routed_experts, dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
@@ -137,18 +158,16 @@ class Glm4MoE(nn.Module):
|
||||
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_physical_experts = (self.n_logical_experts +
|
||||
self.n_redundant_experts)
|
||||
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.physical_expert_start = (self.ep_rank *
|
||||
self.n_local_physical_experts)
|
||||
self.physical_expert_end = (self.physical_expert_start +
|
||||
self.n_local_physical_experts)
|
||||
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
|
||||
self.physical_expert_end = (
|
||||
self.physical_expert_start + self.n_local_physical_experts
|
||||
)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = Glm4MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
@@ -195,7 +214,8 @@ class Glm4MoE(nn.Module):
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
@@ -204,27 +224,27 @@ class Glm4MoE(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
|
||||
|
||||
fused_moe_out = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
fused_moe_out = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output, final_hidden_states = fused_moe_out
|
||||
assert shared_output is not None
|
||||
final_hidden_states = \
|
||||
final_hidden_states * self.routed_scaling_factor\
|
||||
+ shared_output
|
||||
final_hidden_states = (
|
||||
final_hidden_states * self.routed_scaling_factor + shared_output
|
||||
)
|
||||
else:
|
||||
final_hidden_states = fused_moe_out * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states
|
||||
)
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class Glm4MoeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
@@ -266,19 +286,23 @@ class Glm4MoeAttention(nn.Module):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -311,10 +335,12 @@ class Glm4MoeAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q.reshape(-1, self.num_heads,
|
||||
self.head_dim)).reshape(q.shape)
|
||||
k = self.k_norm(k.reshape(-1, self.num_kv_heads,
|
||||
self.head_dim)).reshape(k.shape)
|
||||
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
|
||||
q.shape
|
||||
)
|
||||
k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
|
||||
k.shape
|
||||
)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
@@ -323,7 +349,6 @@ class Glm4MoeAttention(nn.Module):
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
@@ -336,11 +361,10 @@ class Glm4MoeDecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
131072)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 131072)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
layer_idx = int(prefix.split(sep=".")[-1])
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = Glm4MoeAttention(
|
||||
@@ -360,8 +384,10 @@ class Glm4MoeDecoderLayer(nn.Module):
|
||||
use_qk_norm=config.use_qk_norm,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace):
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
):
|
||||
self.mlp = Glm4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@@ -369,16 +395,18 @@ class Glm4MoeDecoderLayer(nn.Module):
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.mlp = Glm4MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
def forward(
|
||||
@@ -391,12 +419,9 @@ class Glm4MoeDecoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
@@ -407,9 +432,9 @@ class Glm4MoeDecoderLayer(nn.Module):
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
}
|
||||
)
|
||||
class Glm4MoeModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -423,9 +448,8 @@ class Glm4MoeModel(nn.Module):
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens"
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
@@ -438,15 +462,16 @@ class Glm4MoeModel(nn.Module):
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -473,27 +498,26 @@ class Glm4MoeModel(nn.Module):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
self, batch_size: int, dtype: torch.dtype, device: torch.device
|
||||
) -> IntermediateTensors:
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
"residual": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
@@ -502,10 +526,10 @@ class Glm4MoeModel(nn.Module):
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
num_experts=self.config.n_routed_experts,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -522,7 +546,7 @@ class Glm4MoeModel(nn.Module):
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -532,7 +556,7 @@ class Glm4MoeModel(nn.Module):
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
@@ -567,14 +591,17 @@ class Glm4MoeModel(nn.Module):
|
||||
# We should ask the weight loader to return success or not
|
||||
# here since otherwise we may skip experts with other
|
||||
# available replicas.
|
||||
weight_loader = typing.cast(Callable[..., bool],
|
||||
param.weight_loader)
|
||||
success = weight_loader(param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True)
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
@@ -598,8 +625,9 @@ class Glm4MoeModel(nn.Module):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
@@ -627,24 +655,26 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Glm4MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = Glm4MoeModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
self.expert_weights = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (config.num_hidden_layers -
|
||||
config.first_k_dense_replace)
|
||||
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
|
||||
self.num_expert_groups = config.n_group
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
@@ -695,8 +725,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
@@ -706,8 +737,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@@ -715,13 +745,14 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig,
|
||||
weight_name: str) -> Optional[int]:
|
||||
if hasattr(config,
|
||||
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
|
||||
> 0):
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
config: Glm4MoeConfig, weight_name: str
|
||||
) -> Optional[int]:
|
||||
if hasattr(config, "num_nextn_predict_layers") and (
|
||||
config.num_nextn_predict_layers > 0
|
||||
):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if f"layers.{layer_idx+i}." in weight_name:
|
||||
if f"layers.{layer_idx + i}." in weight_name:
|
||||
return layer_idx + i
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user