[Model] Cleanup: Remove redundant manual definition of make_empty_intermediate_tensors in GLM-4-MoE (#31869)
Signed-off-by: maang <maang_h@163.com>
This commit is contained in:
@@ -478,20 +478,6 @@ class Glm4MoeModel(nn.Module):
|
|||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def make_empty_intermediate_tensors(
|
|
||||||
self, batch_size: int, dtype: torch.dtype, device: torch.device
|
|
||||||
) -> IntermediateTensors:
|
|
||||||
return IntermediateTensors(
|
|
||||||
{
|
|
||||||
"hidden_states": torch.zeros(
|
|
||||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
|
||||||
),
|
|
||||||
"residual": torch.zeros(
|
|
||||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user