[Core] Refactor Attention Take 2 (#3462)
@@ -42,8 +42,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -67,8 +66,6 @@ from vllm.sequence import SamplerOutput
 # this model must need this dependency
 from hf_olmo import OLMoConfig
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 class SwiGLU(nn.Module):
 
@@ -146,16 +143,15 @@ class OlmoAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.attn_norm(hidden_states)
         qkv, _ = self.att_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.config.rope:
             q, k = self.rotary_emb(positions, q, k)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.attn_out(attn_output)
         return output
 
@@ -241,12 +237,12 @@ class OlmoBlock(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Attention block.
         og_x = hidden_states
-        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
+        x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
         x = x + og_x
 
         # MLP block.
@@ -296,8 +292,8 @@ class OlmoModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
@@ -313,7 +309,7 @@ class OlmoModel(nn.Module):
                 positions,
                 x,
                 kv_caches[block_idx],
-                input_metadata,
+                attn_metadata,
             )
 
         # Apply final layer norm.
@@ -344,14 +340,14 @@ class OLMoForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
        )
         return hidden_states
 
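
For models still on the old interface, a minimal sketch of the calling convention this refactor moves to. The `ToyAttention` wrapper below is illustrative only and not part of this PR; it assumes the `Attention` layer from `vllm.attention` keeps a `(num_heads, head_size, scale)` constructor as in the existing vLLM code.

    # Sketch, not part of the PR: shows the new forward signature where the
    # KV cache is a single fused tensor and AttentionMetadata replaces
    # InputMetadata.
    import torch
    from torch import nn

    from vllm.attention import Attention, AttentionMetadata


    class ToyAttention(nn.Module):

        def __init__(self, num_heads: int, head_size: int):
            super().__init__()
            # Assumed constructor arguments; scale is the usual 1/sqrt(d).
            self.attn = Attention(num_heads, head_size, scale=head_size**-0.5)

        def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            kv_cache: torch.Tensor,            # fused KV-cache tensor, not a
                                               # (k_cache, v_cache) tuple
            attn_metadata: AttentionMetadata,  # replaces InputMetadata
        ) -> torch.Tensor:
            # Old call: self.attn(q, k, v, k_cache, v_cache, input_metadata)
            return self.attn(q, k, v, kv_cache, attn_metadata)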