[Model] Remove the unnecessary dtype conversion in MiniCPM (#32523)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
This commit is contained in:
@@ -300,10 +300,7 @@ class MiniCPMAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
orig_dtype = q.dtype
|
||||
q, k = q.float(), k.float()
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q, k = q.to(orig_dtype), k.to(orig_dtype)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user