Support MPT with GQA (#1938)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Megha Agarwal
2023-12-12 10:16:05 -08:00
committed by GitHub
parent 7e1b21daac
commit 6428f1d051
2 changed files with 28 additions and 6 deletions

View File

@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
linear_method=linear_method,
)
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)