Minor fixes for Mixtral (#2015)

This commit is contained in:
Woosuk Kwon
2023-12-11 09:16:15 -08:00
committed by GitHub
parent b5f882cc98
commit 4ff0203987
2 changed files with 5 additions and 6 deletions

View File

@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
import numpy as np
@@ -453,10 +453,6 @@ class MixtralForCausalLM(nn.Module):
assert linear_method is None
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
-self.tok_embeddings: Union[nn.Embedding, None] = None
-self.layers: nn.ModuleList = None
-self.output: Union[nn.Linear, None] = None
-self.sampler: Union[Sampler, None] = None
self.tok_embeddings = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@@ -492,6 +488,7 @@ class MixtralForCausalLM(nn.Module):
input_metadata,
cache_event,
)
+hidden_states = self.norm(hidden_states)
return hidden_states
def sample(
@@ -499,7 +496,6 @@ class MixtralForCausalLM(nn.Module):
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
-hidden_states = self.norm(hidden_states)
next_tokens = self.sampler(self.output.weight, hidden_states,
sampling_metadata)
return next_tokens