[Core] Refactor Attention Take 2 (#3462)

Author: Woosuk Kwon
Date: 2024-03-24 21:39:33 -07:00 (committed by GitHub)
Parent: b0dfa91dd7
Commit: 925f3332ca
47 changed files with 1268 additions and 1117 deletions


@@ -28,9 +28,8 @@ import torch
 from torch import nn
 from transformers import Qwen2Config

-from vllm.model_executor.input_metadata import InputMetadata
+from vllm.attention import Attention, AttentionMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput
 from vllm.config import LoRAConfig

-KVCache = Tuple[torch.Tensor, torch.Tensor]


 class Qwen2MLP(nn.Module):
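
The two hunks above change the model-facing types: Attention and the new AttentionMetadata are now imported from vllm.attention (the old vllm.model_executor.layers.attention and vllm.model_executor.input_metadata imports are dropped), and the module-level KVCache tuple alias is deleted because each layer's cache is now a single tensor. A minimal, hedged sketch of the type-level change, using only names that appear in this diff; OldKVCache and NewKVCache are illustrative labels, not vLLM names:

    from typing import List, Tuple
    import torch
    from vllm.attention import Attention, AttentionMetadata  # new import path in this commit

    # Before this commit: each layer's cache was a (key_cache, value_cache) tuple,
    # spelled via the removed module-level KVCache alias.
    OldKVCache = Tuple[torch.Tensor, torch.Tensor]

    # After this commit: each layer's cache is one torch.Tensor, and batch/attention
    # bookkeeping travels as an AttentionMetadata object instead of InputMetadata.
    NewKVCache = torch.Tensor
    NewKVCaches = List[torch.Tensor]  # what Qwen2Model.forward now receives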
@@ -147,14 +144,13 @@ class Qwen2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
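
This forward is the core of the refactor at the layer level: the cache is no longer unpacked into k_cache and v_cache, and the metadata argument changes both name and type. A hedged sketch of the new calling convention, with attn_layer standing in for self.attn (an Attention instance whose construction is outside this diff):

    import torch
    from vllm.attention import AttentionMetadata

    def run_attention(attn_layer, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      kv_cache: torch.Tensor,
                      attn_metadata: AttentionMetadata) -> torch.Tensor:
        # Removed style: k_cache, v_cache = kv_cache
        #                attn_layer(q, k, v, k_cache, v_cache, input_metadata)
        # New style: pass the whole cache tensor and the AttentionMetadata through.
        return attn_layer(q, k, v, kv_cache, attn_metadata)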
@@ -197,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -212,7 +208,7 @@
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -248,8 +244,8 @@ class Qwen2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -259,7 +255,7 @@
                 positions,
                 hidden_states,
                 kv_caches[i],
-                input_metadata,
+                attn_metadata,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
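
At the model level the same renaming threads through the layer loop: each decoder layer receives its own entry of kv_caches while a single attn_metadata object is shared across the whole forward pass. A hedged sketch of that loop, with layers standing in for self.layers and the final norm omitted:

    from typing import List, Optional
    import torch
    from vllm.attention import AttentionMetadata

    def run_layers(layers, positions: torch.Tensor, hidden_states: torch.Tensor,
                   kv_caches: List[torch.Tensor],
                   attn_metadata: AttentionMetadata) -> torch.Tensor:
        residual: Optional[torch.Tensor] = None
        for i, layer in enumerate(layers):
            # One cache tensor per layer, one shared AttentionMetadata object.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata, residual)
        return hidden_states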
@@ -315,11 +311,11 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+                                   attn_metadata)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
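
For callers, the visible change is the last argument of Qwen2ForCausalLM.forward (and, per the commit stats, the same mechanical change is applied across the other touched files): build an AttentionMetadata instead of an InputMetadata and pass it through. A hedged one-line sketch, assuming model, input_ids, positions, kv_caches, and attn_metadata have already been prepared by the runner:

    hidden_states = model(input_ids, positions, kv_caches, attn_metadata)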