Minor fixes for Mixtral (#2015)
This commit is contained in:
@@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -453,10 +453,6 @@ class MixtralForCausalLM(nn.Module):
|
||||
assert linear_method is None
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.tok_embeddings: Union[nn.Embedding, None] = None
|
||||
self.layers: nn.ModuleList = None
|
||||
self.output: Union[nn.Linear, None] = None
|
||||
self.sampler: Union[Sampler, None] = None
|
||||
self.tok_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
@@ -492,6 +488,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
@@ -499,7 +496,6 @@ class MixtralForCausalLM(nn.Module):
|
||||
hidden_states: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
next_tokens = self.sampler(self.output.weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
Reference in New Issue
Block a user