Re-enable the 80 char line width limit (#3305)

Author: Zhuohan Li
Date: 2024-03-10 19:49:14 -07:00
Committed by: GitHub
Parent: 4b59f00e91
Commit: 2f8844ba08
67 changed files with 557 additions and 528 deletions
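
Every hunk in this commit is mechanical line wrapping: code that ran past 80 columns is rewrapped, with no behavioral change. As background, the limit itself is easy to check with a standalone script. The following is an illustrative sketch, not part of the commit (the repository's own lint tooling, not shown here, is what actually enforces the limit):

    import sys

    MAX_LEN = 80  # the limit this commit re-enables

    # Report every line longer than MAX_LEN in the files given on the
    # command line, e.g.: python check_len.py vllm/**/*.py
    for path in sys.argv[1:]:
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                stripped = line.rstrip("\n")
                if len(stripped) > MAX_LEN:
                    print(f"{path}:{lineno}: {len(stripped)} chars")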


@@ -52,7 +52,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size, )
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -81,7 +82,8 @@ class SwiGLU(nn.Module):
 class OlmoAttention(nn.Module):
     """
-    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    This is the attention block where the output is computed as
+    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
@@ -94,11 +96,12 @@ class OlmoAttention(nn.Module):
         self.config = config
         self.hidden_size = config.d_model
         assert config.d_model % config.n_heads == 0
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.total_num_heads = self.config.n_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         # Layer norms.
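
The two idioms in this hunk, parenthesizing the right-hand side of an assignment and breaking after a binary operator inside parentheses, are purely syntactic: the parentheses exist only to permit the line break. A minimal standalone check (the variable names are local stand-ins for the ones in the diff) confirms the rewrapped form parses to an identical AST:

    import ast

    old = "num_heads = total_num_heads // tensor_model_parallel_world_size"
    new = ("num_heads = (total_num_heads //\n"
           "             tensor_model_parallel_world_size)")

    # Redundant parentheses and line breaks are discarded by the parser,
    # so both forms produce the same AST and the same behavior.
    assert ast.dump(ast.parse(old)) == ast.dump(ast.parse(new))
    print("identical ASTs")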
@@ -158,7 +161,8 @@ class OlmoAttention(nn.Module):
 class OlmoMLP(nn.Module):
     """
-    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    This is the MLP block where the output is computed as
+    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
@@ -217,7 +221,8 @@ class OlmoMLP(nn.Module):
 class OlmoBlock(nn.Module):
     """
-    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    This is a typical transformer block where the output is
+    computed as ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """