Fix 1D query issue from _prune_hidden_states (#3539)
This commit is contained in:
@@ -77,7 +77,6 @@ def _prune_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
return hidden_states.index_select(0,
|
||||
sampling_metadata.selected_token_indices)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user