[Bugfix] Fix hybrid model tests (#17182)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -531,7 +531,10 @@ class HfRunner:
|
||||
for _, hidden_state in enumerate(hidden_states):
|
||||
last_hidden_states = hidden_state[-1][0]
|
||||
logits = torch.matmul(
|
||||
last_hidden_states.to(output_embeddings.weight.device),
|
||||
last_hidden_states.to(
|
||||
device=output_embeddings.weight.device,
|
||||
dtype=output_embeddings.weight.dtype,
|
||||
),
|
||||
output_embeddings.weight.t(),
|
||||
)
|
||||
if getattr(output_embeddings, "bias", None) is not None:
|
||||
|
||||
Reference in New Issue
Block a user