[Bugfix] Fix hybrid model tests (#17182)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-04-26 06:14:37 +08:00
committed by GitHub
parent 48cb2109b6
commit 43faa0461a
3 changed files with 159 additions and 535 deletions

View File

@@ -531,7 +531,10 @@ class HfRunner:
for _, hidden_state in enumerate(hidden_states):
last_hidden_states = hidden_state[-1][0]
logits = torch.matmul(
last_hidden_states.to(output_embeddings.weight.device),
last_hidden_states.to(
device=output_embeddings.weight.device,
dtype=output_embeddings.weight.dtype,
),
output_embeddings.weight.t(),
)
if getattr(output_embeddings, "bias", None) is not None: