change the timing of sorting logits (#1309)
This commit is contained in:
@@ -102,30 +102,24 @@ def _prune_hidden_states(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
last_token_indices = {t: [] for t in SamplingType}
|
last_token_indices = []
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, _ = seq_group
|
||||||
sampling_type = sampling_params.sampling_type
|
|
||||||
if i < input_metadata.num_prompts:
|
if i < input_metadata.num_prompts:
|
||||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||||
prompt_len = input_metadata.prompt_lens[i]
|
prompt_len = input_metadata.prompt_lens[i]
|
||||||
last_token_indices[sampling_type].append(start_idx + prompt_len -
|
last_token_indices.append(start_idx + prompt_len - 1)
|
||||||
1)
|
|
||||||
start_idx += prompt_len
|
start_idx += prompt_len
|
||||||
else:
|
else:
|
||||||
num_seqs = len(seq_ids)
|
num_seqs = len(seq_ids)
|
||||||
last_token_indices[sampling_type].extend(
|
last_token_indices.extend(range(start_idx, start_idx + num_seqs))
|
||||||
range(start_idx, start_idx + num_seqs))
|
|
||||||
start_idx += num_seqs
|
start_idx += num_seqs
|
||||||
|
|
||||||
all_last_token_indices = []
|
last_token_indices = torch.tensor(last_token_indices,
|
||||||
for sampling_type in SamplingType:
|
|
||||||
all_last_token_indices.extend(last_token_indices[sampling_type])
|
|
||||||
all_last_token_indices = torch.tensor(all_last_token_indices,
|
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=hidden_states.device)
|
device=hidden_states.device)
|
||||||
return hidden_states.index_select(0, all_last_token_indices)
|
return hidden_states.index_select(0, last_token_indices)
|
||||||
|
|
||||||
|
|
||||||
def _get_penalties(
|
def _get_penalties(
|
||||||
@@ -424,27 +418,26 @@ def _sample(
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||||
category_num_tokens = {t: 0 for t in SamplingType}
|
start_idx = 0
|
||||||
|
categorized_seq_ids = {t: [] for t in SamplingType}
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
sampling_type = sampling_params.sampling_type
|
sampling_type = sampling_params.sampling_type
|
||||||
categorized_seq_group_ids[sampling_type].append(i)
|
categorized_seq_group_ids[sampling_type].append(i)
|
||||||
num_seqs = len(seq_ids)
|
num_seqs = len(seq_ids)
|
||||||
category_num_tokens[sampling_type] += num_seqs
|
categorized_seq_ids[sampling_type].extend(
|
||||||
|
range(start_idx, start_idx + num_seqs))
|
||||||
|
start_idx += num_seqs
|
||||||
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
|
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
|
||||||
category_start_idx = 0
|
|
||||||
for sampling_type in SamplingType:
|
for sampling_type in SamplingType:
|
||||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||||
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
||||||
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
||||||
num_tokens = category_num_tokens[sampling_type]
|
num_tokens = len(categorized_seq_ids[sampling_type])
|
||||||
if num_tokens == 0:
|
if num_tokens == 0:
|
||||||
continue
|
continue
|
||||||
category_logprobs = logprobs[category_start_idx:category_start_idx +
|
category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
|
||||||
num_tokens]
|
category_probs = probs[categorized_seq_ids[sampling_type]]
|
||||||
category_probs = probs[category_start_idx:category_start_idx +
|
|
||||||
num_tokens]
|
|
||||||
if sampling_type == SamplingType.GREEDY:
|
if sampling_type == SamplingType.GREEDY:
|
||||||
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
||||||
elif sampling_type == SamplingType.RANDOM:
|
elif sampling_type == SamplingType.RANDOM:
|
||||||
@@ -497,6 +490,5 @@ def _sample(
|
|||||||
sample_idx += num_parent_seqs
|
sample_idx += num_parent_seqs
|
||||||
result_idx += num_results
|
result_idx += num_results
|
||||||
assert sample_idx == num_tokens
|
assert sample_idx == num_tokens
|
||||||
category_start_idx += num_tokens
|
|
||||||
|
|
||||||
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
||||||
|
|||||||
Reference in New Issue
Block a user