Fix 1D query issue from _prune_hidden_states (#3539)

This commit is contained in:
SangBin Cho
2024-03-21 17:49:06 +09:00
committed by GitHub
parent 6ebd02bdef
commit 3bbff9e5ab

View File

@@ -77,7 +77,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)