[Feature] minicpm eagle support (#18943)

Signed-off-by: huangyuxiang03 <huangyx0321@gmail.com>
Co-authored-by: huangyuxiang03 <huangyx0321@gmail.com>
Author: Shawn Huang
Date: 2025-05-30 21:45:56 +08:00
Committed by: GitHub
Parent: 43ff405b90
Commit: e1fadf1197

4 changed files with 399 additions and 2 deletions
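This PR adds EAGLE speculative-decoding support for MiniCPM; most of the 399 added lines are a new draft-model definition, while the `minicpm.py` hunks below adjust the target model so its hidden states are usable by the draft. As a rough usage sketch, with the model paths and the exact `speculative_config` keys as assumptions rather than values taken from this diff, an EAGLE draft is typically enabled like this:

```python
from vllm import LLM, SamplingParams

# Hypothetical checkpoints; substitute a real MiniCPM target model and a
# matching EAGLE draft trained against its hidden states.
llm = LLM(
    model="openbmb/MiniCPM-2B-sft-bf16",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle",
        "model": "path/to/minicpm-eagle-draft",
        "num_speculative_tokens": 3,
    },
)
params = SamplingParams(temperature=0.0, max_tokens=32)
print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)
```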

vllm/model_executor/models/minicpm.py

```diff
@@ -242,6 +242,7 @@ class MiniCPMAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
```
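For orientation, this hunk sits in the standard vLLM attention setup: a `get_rope` rotary embedding followed by the paged `Attention` module. A minimal sketch of the forward pattern that consumes these modules (simplified; the real class also shards heads across tensor-parallel ranks):

```python
import torch

def attention_forward(self, positions: torch.Tensor,
                      hidden_states: torch.Tensor) -> torch.Tensor:
    # Fused QKV projection, then split into per-role tensors.
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    # Rotary position embeddings are applied to q/k before the kernel call.
    q, k = self.rotary_emb(positions, q, k)
    # Paged-attention kernel reads/writes the KV cache internally.
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
```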
```diff
@@ -444,6 +445,7 @@ class MiniCPMModel(nn.Module):
             for weight_name in ["w1", "w2", "w3"]
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
```
```diff
@@ -567,7 +569,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
+                                   inputs_embeds) / self.scale_width
         return hidden_states

     def compute_logits(
```
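MiniCPM scales residual-stream activations by a width multiplier derived from its config, and previously undid that scaling only inside `compute_logits`. Moving the `/ self.scale_width` into `forward` means the hidden states the model hands back, which an EAGLE draft model conditions on, are already in the scale the LM head expects. The two placements are numerically equivalent on the logits path, as this toy check illustrates (the value of `scale_width` here is illustrative):

```python
import torch

torch.manual_seed(0)
scale_width = 9.0                      # illustrative; MiniCPM computes it from config
lm_head = torch.nn.Linear(16, 32, bias=False)
x = torch.randn(2, 4, 16)              # raw hidden states off the final layer

logits_old = lm_head(x / scale_width)  # old: scale applied in compute_logits
hidden = x / scale_width               # new: forward() returns scaled states
logits_new = lm_head(hidden)

assert torch.allclose(logits_old, logits_new)
```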
```diff
@@ -575,7 +577,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        hidden_states = hidden_states / self.scale_width
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
```
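With the division removed here (it now happens in `forward`), the logits path avoids scaling twice. The hidden-state scale matters because of the EAGLE architecture itself: the draft model conditions each speculative step on the target model's hidden states. A rough structural sketch follows; the class and layer names are assumptions for illustration, not the contents of the new `minicpm_eagle.py`:

```python
import torch
from torch import nn


class EagleDraftSketch(nn.Module):
    """Toy EAGLE-style draft head: fuses the current token embedding with
    the target model's hidden state before a shallow decoder layer."""

    def __init__(self, hidden_size: int, vocab_size: int, nhead: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # Projects [token_embedding ; target_hidden_state] back to hidden_size.
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.layer = nn.TransformerEncoderLayer(hidden_size, nhead,
                                                batch_first=True)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor,
                target_hidden: torch.Tensor) -> torch.Tensor:
        # target_hidden: (batch, seq, hidden) from the target model's forward().
        fused = torch.cat([self.embed(input_ids), target_hidden], dim=-1)
        return self.lm_head(self.layer(self.fc(fused)))
```

Because `forward` now returns `hidden_states / self.scale_width`, the `target_hidden` consumed by a draft head like this arrives pre-scaled, so no extra normalization is needed on the draft side.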