Optimize MQA Kernel (#452)
This commit is contained in:
@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from transformers import GPTBigCodeConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.num_kv_heads = 1 if config.multi_query else self.num_heads
|
||||
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
self.hidden_size + 2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
|
||||
perform_initialization=False)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale)
|
||||
scale=self.scale,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
|
||||
dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
def _expand_mqa_mha(qkv_array, n_head, head_dim):
|
||||
"""manipulates along axis=0 from MQA to MHA
|
||||
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
|
||||
with n_heads for q, then 1 for k, 1 for 1 v, times head dim
|
||||
return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
|
||||
|
||||
TODO: this function is no longer needed once vllm supports MQA.
|
||||
"""
|
||||
qkv_array = qkv_array.numpy()
|
||||
|
||||
dims_q = n_head * head_dim
|
||||
# pylint: disable=unbalanced-tuple-unpacking
|
||||
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
|
||||
axis=0)
|
||||
# q is fine, but k & v have not replicated shape along the first
|
||||
# axis as long as MQA is not nativly supported, increase memory
|
||||
# and replicated (head_dim, hidden_dim) to
|
||||
# (n_heads * head_dim, hidden_dim)
|
||||
if k.ndim == 2 and v.ndim == 2:
|
||||
replication = (n_head, 1) # weights
|
||||
else:
|
||||
replication = n_head # biases
|
||||
# replicate n_head times for q, v
|
||||
k, v = np.tile(k, replication), np.tile(v, replication)
|
||||
# concat q, k, v along the first axis
|
||||
# (n_heads * head_dim, hidden_dim)
|
||||
# to (3 * n_heads * head_dim, hidden_dim)
|
||||
qkv_array = np.concatenate((q, k, v), axis=0)
|
||||
return torch.from_numpy(qkv_array)
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of
|
||||
@@ -300,30 +273,27 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
# When tensor parallelism is used, we shard the weights along
|
||||
# the head dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
total_num_kv_heads = (1 if self.config.multi_query else
|
||||
total_num_heads)
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
total_kv_size = head_size * total_num_kv_heads
|
||||
num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
head_start = tensor_model_parallel_rank * num_heads
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = _expand_mqa_mha(loaded_weight,
|
||||
n_head=total_num_heads,
|
||||
head_dim=head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif name.endswith(".bias"):
|
||||
loaded_weight = _expand_mqa_mha(loaded_weight,
|
||||
n_head=total_num_heads,
|
||||
head_dim=head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected parameter name {name}")
|
||||
wq, wk, wv = torch.split(
|
||||
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
||||
dim=0)
|
||||
|
||||
wq = wq[head_size * head_start:head_size * head_end]
|
||||
if not self.config.multi_query:
|
||||
# Split the heads when using normal multi-head attention
|
||||
wk = wk[head_size * head_start:head_size * head_end]
|
||||
wv = wv[head_size * head_start:head_size * head_end]
|
||||
# Else, keep the weights as is for multi-query attention
|
||||
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
|
||||
Reference in New Issue
Block a user