Order sequence ids + config update to support specifying custom quantization layers (#18279)
Signed-off-by: Elaine Zhao <elaineyz@amazon.com> Co-authored-by: Tailin Pan <tailinpa@amazon.com> Co-authored-by: Rishabh Rajesh <rishyraj@amazon.com> Co-authored-by: Yishan McNabb <yishanm@amazon.com> Co-authored-by: Patrick Lange <patlange@amazon.com> Co-authored-by: Maxwell Goldberg <mgld@amazon.com> Co-authored-by: Aakash Shetty <sheaak@amazon.com>
This commit is contained in:
@@ -87,16 +87,29 @@ class NeuronCausalLM(nn.Module):
|
||||
input_block_ids: torch.Tensor,
|
||||
sampling_params: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
|
||||
output = self.model(input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=positions,
|
||||
seq_ids=input_block_ids,
|
||||
seq_ids=sorted_input_block_ids,
|
||||
sampling_params=sampling_params)
|
||||
# on-device sampling
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
return output.hidden_states
|
||||
output = output.hidden_states
|
||||
else:
|
||||
return output.logits[:, -1, :]
|
||||
output = output.logits[:, -1, :]
|
||||
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
if input_block_ids.shape[0] != 1:
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
|
||||
return output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
@@ -340,14 +353,26 @@ class NeuronSpeculationCausalLM(nn.Module):
|
||||
input_block_ids: torch.Tensor,
|
||||
sampling_params: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
|
||||
output = self.model(input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=positions,
|
||||
seq_ids=input_block_ids,
|
||||
seq_ids=sorted_input_block_ids,
|
||||
sampling_params=sampling_params)
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
|
||||
# CTX encoding
|
||||
if (positions[:, 0]).sum().item() == 0:
|
||||
return output.fused_outputs[0][:, 0:1]
|
||||
output = output.fused_outputs[0][:, 0:1]
|
||||
if input_block_ids.shape[0] != 1:
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
return output
|
||||
|
||||
# Fused Spec (Generation)
|
||||
accepted_tokens_with_padding = output.fused_outputs[0]
|
||||
@@ -362,6 +387,10 @@ class NeuronSpeculationCausalLM(nn.Module):
|
||||
-1) >= generated_token_counts
|
||||
accepted_tokens_with_padding[mask] = -1
|
||||
|
||||
if input_block_ids.shape[0] != 1:
|
||||
accepted_tokens_with_padding = torch.index_select(
|
||||
accepted_tokens_with_padding, 0, restored_indices)
|
||||
|
||||
return accepted_tokens_with_padding
|
||||
|
||||
def sample(
|
||||
@@ -416,6 +445,10 @@ class NeuronSpeculationCausalLM(nn.Module):
|
||||
draft_neuron_config.speculation_length = 0
|
||||
draft_neuron_config.trace_tokengen_model = True
|
||||
draft_neuron_config.enable_fused_speculation = False
|
||||
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
|
||||
None):
|
||||
draft_neuron_config.modules_to_not_convert = (
|
||||
draft_neuron_config.draft_model_modules_to_not_convert)
|
||||
if config.neuron_config.enable_eagle_speculation:
|
||||
draft_neuron_config.is_eagle_draft = True
|
||||
draft_neuron_config.sequence_parallel_enabled = False
|
||||
|
||||
Reference in New Issue
Block a user