[Model] Remove the unnecessary dtype conversion in MiniCPM (#32523)

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
This commit is contained in:
Canlin Guo
2026-01-18 16:07:28 +08:00
committed by GitHub
parent 963dc0b865
commit fe36bf5e80

View File

@@ -300,10 +300,7 @@ class MiniCPMAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
orig_dtype = q.dtype
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output