[Hardware][Neuron] Refactor neuron support (#3471)

This commit is contained in:
Zhuohan Li
2024-03-21 18:22:17 -07:00
committed by GitHub
parent ea5f14e6ff
commit e90fc21f2e
33 changed files with 615 additions and 549 deletions

View File

@@ -4,8 +4,6 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.utils import is_neuron
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -23,7 +21,8 @@ class LogitsProcessor(nn.Module):
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: Optional[float] = 1.0) -> None:
scale: Optional[float] = 1.0,
logits_as_input: bool = False) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
@@ -31,8 +30,8 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.scale = scale
self.vocab_size = vocab_size
# Transformers-neuronx generate outputs as logits directly.
self.logits_as_hidden_states = is_neuron()
# Whether the input is logits (default is hidden states).
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
@@ -43,7 +42,7 @@ class LogitsProcessor(nn.Module):
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.logits_as_hidden_states:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,