Support MPT with GQA (#1938)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.total_num_heads = config.n_heads
|
||||
self.head_dim = self.d_model // self.total_num_heads
|
||||
self.clip_qkv = config.attn_config["clip_qkv"]
|
||||
self.qk_ln = config.attn_config["qk_ln"]
|
||||
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
|
||||
if "kv_n_heads" in config.attn_config:
|
||||
self.total_num_kv_heads = config.attn_config['kv_n_heads']
|
||||
else:
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
assert not config.attn_config["prefix_lm"]
|
||||
assert config.attn_config["alibi"]
|
||||
|
||||
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
|
||||
self.d_model,
|
||||
self.d_model // self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=not config.no_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
if self.total_num_kv_heads >= tp_world_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_world_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_world_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
# Create the alibi slopes and slice them.
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes)
|
||||
alibi_slopes=alibi_slopes,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.qk_ln:
|
||||
q = self.q_ln(q)
|
||||
k = self.k_ln(k)
|
||||
|
||||
Reference in New Issue
Block a user