diff --git a/examples/offline_inference/extract_hidden_states.py b/examples/offline_inference/extract_hidden_states.py index 61299101c..551f13761 100644 --- a/examples/offline_inference/extract_hidden_states.py +++ b/examples/offline_inference/extract_hidden_states.py @@ -54,5 +54,5 @@ with tempfile.TemporaryDirectory() as tmpdirname: print("Extracted token ids:", token_ids) # Matches prompt token ids print( "Extracted hidden states shape:", hidden_states.shape - ) # [num_hidden_layers, prompt len, hidden size] + ) # [prompt len, num_hidden_layers, hidden size] print("Extracted hidden states:", hidden_states)