Re-enable the 80 char line width limit (#3305)
@@ -52,7 +52,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size, )
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -81,7 +82,8 @@ class SwiGLU(nn.Module):
 
 class OlmoAttention(nn.Module):
     """
-    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    This is the attention block where the output is computed as
+    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
@@ -94,11 +96,12 @@ class OlmoAttention(nn.Module):
         self.config = config
         self.hidden_size = config.d_model
         assert config.d_model % config.n_heads == 0
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.total_num_heads = self.config.n_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // self.total_num_heads
 
         # Layer norms.
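For context on the expressions re-wrapped above: the total attention head count is split evenly across tensor-parallel ranks, which is what the assert guards. A minimal illustration with hypothetical numbers (not taken from this commit):

    # Hypothetical example: 32 total heads on 4 tensor-parallel ranks.
    total_num_heads = 32
    tensor_model_parallel_world_size = 4
    assert total_num_heads % tensor_model_parallel_world_size == 0
    num_heads = (total_num_heads //
                 tensor_model_parallel_world_size)  # 8 heads per rank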
@@ -158,7 +161,8 @@ class OlmoAttention(nn.Module):
 
 class OlmoMLP(nn.Module):
     """
-    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    This is the MLP block where the output is computed as
+    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
@@ -217,7 +221,8 @@ class OlmoMLP(nn.Module):
 
 class OlmoBlock(nn.Module):
     """
-    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    This is a typical transformer block where the output is
+    computed as ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
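The hunks above only re-wrap over-long lines so the re-enabled 80-character limit passes; behaviour is unchanged. As a rough illustration of what such a limit checks, a standalone sketch (hypothetical helper, not vLLM's actual lint tooling) could look like this:

    # check_line_width.py -- hypothetical helper, not part of this commit.
    import sys

    LIMIT = 80

    def long_lines(path, limit=LIMIT):
        """Yield (line number, width) for lines exceeding the limit."""
        with open(path) as f:
            for lineno, line in enumerate(f, start=1):
                width = len(line.rstrip("\n"))
                if width > limit:
                    yield lineno, width

    if __name__ == "__main__":
        for path in sys.argv[1:]:
            for lineno, width in long_lines(path):
                print(f"{path}:{lineno}: {width} > {LIMIT} characters")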