Optimize MQA Kernel (#452)

Zhuohan Li
2023-07-14 20:06:40 -04:00
committed by GitHub
parent dbed69058c
commit 96853af5a8
5 changed files with 84 additions and 72 deletions


@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import numpy as np
from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.num_kv_heads = 1 if config.multi_query else self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
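As a rough, standalone illustration of why the fused projection shrinks under MQA (the numbers below are hypothetical, StarCoder-like sizes chosen only for this sketch, not values taken from the diff):

# Hypothetical sizes for illustration only.
hidden_size = 6144
total_num_heads = 48
head_dim = hidden_size // total_num_heads        # 128
multi_query = True

num_kv_heads = 1 if multi_query else total_num_heads
kv_dim = num_kv_heads * head_dim                 # 128 under MQA, 6144 under MHA

# Old layout replicated K/V to full MHA width: 3 * hidden_size = 18432 output features.
# New layout keeps a single K/V head: hidden_size + 2 * kv_dim = 6400 output features.
print(3 * hidden_size, hidden_size + 2 * kv_dim)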
@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
scale=self.scale,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
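A minimal sketch of the new split in forward(), run outside the model with hypothetical shapes (a single KV head and a made-up token count):

import torch

num_tokens, hidden_size, kv_dim = 4, 6144, 128   # hypothetical sizes
qkv = torch.randn(num_tokens, hidden_size + 2 * kv_dim)

# Instead of chunk(3), split by the exact widths of Q, K, and V.
q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
print(q.shape, k.shape, v.shape)
# torch.Size([4, 6144]) torch.Size([4, 128]) torch.Size([4, 128])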
@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
def _expand_mqa_mha(qkv_array, n_head, head_dim):
"""manipulates along axis=0 from MQA to MHA
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
with n_heads for q, then 1 for k, 1 for v, times head_dim
return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
TODO: this function is no longer needed once vllm supports MQA.
"""
qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim
# pylint: disable=unbalanced-tuple-unpacking
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
axis=0)
# q is fine, but k & v have not replicated shape along the first
# axis as long as MQA is not nativly supported, increase memory
# and replicated (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights
else:
replication = n_head # biases
# replicate k, v n_head times each
k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
@@ -300,30 +273,27 @@ class GPTBigCodeForCausalLM(nn.Module):
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
# Else, keep the weights as is for multi-query attention
loaded_weight = torch.cat([wq, wk, wv], dim=0)
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
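To make the new sharding path concrete, here is a hedged, standalone sketch of how a fused c_attn checkpoint weight could be split and sharded per tensor-parallel rank under MQA; the configuration and rank below are hypothetical and chosen only for illustration:

import torch

# Hypothetical configuration for illustration only.
hidden_size, total_num_heads = 6144, 48
tp_size, tp_rank = 4, 1
multi_query = True

head_size = hidden_size // total_num_heads            # 128
total_num_kv_heads = 1 if multi_query else total_num_heads
total_kv_size = head_size * total_num_kv_heads        # 128 under MQA
num_heads = total_num_heads // tp_size                # 12 heads per rank
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

# Checkpoint layout: [Q | K | V] stacked along dim 0.
loaded_weight = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)
wq, wk, wv = torch.split(
    loaded_weight, [hidden_size, total_kv_size, total_kv_size], dim=0)

# Q is sharded by head range; the single K/V head stays whole on every rank.
wq = wq[head_size * head_start:head_size * head_end]
if not multi_query:
    wk = wk[head_size * head_start:head_size * head_end]
    wv = wv[head_size * head_start:head_size * head_end]
sharded = torch.cat([wq, wk, wv], dim=0)
print(sharded.shape)   # torch.Size([1792, 6144]) = (12 * 128 + 2 * 128, 6144)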