[Core] Optimizing cross-attention QKVParallelLinear computation (#12325)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
Co-authored-by: NickLucche <nick@nlucches-4xa100.c.openshift-330514.internal>
@@ -1227,3 +1227,98 @@ class RowParallelLinear(LinearBase):
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s


class QKVCrossParallelLinear(torch.nn.Module):

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Empty placeholders for loading as a single module.
        self.weight = torch.nn.Parameter()
        set_weight_attrs(self.weight, {
            "weight_loader": self.weight_loader_weight,
        })
        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "weight_loader": self.weight_loader_bias,
            })

    @property
    def q_proj_decoder(self):
        return self.proj["q_proj_decoder"]

    @property
    def kv_proj_encoder(self):
        return self.proj["kv_proj_encoder"]

    def forward(self, decoder_hidden_states, encoder_hidden_states):
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_weight(self,
                             param: torch.nn.Parameter,
                             loaded_weight: torch.Tensor,
                             loaded_shard_id: Optional[str] = None):
        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
            else self.kv_proj_encoder.weight
        param.weight_loader(
            param,
            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
                param, loaded_weight, loaded_shard_id)

    def weight_loader_bias(self,
                           param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor,
                           loaded_shard_id: Optional[str] = None):
        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
            else self.kv_proj_encoder.bias
        param.weight_loader(
            param,
            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
                param, loaded_weight, loaded_shard_id)
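
As a usage illustration (not part of the diff), here is a minimal sketch of how a cross-attention layer might call the new module as a drop-in QKV projection: Q is computed from decoder states, K/V from encoder states during prefill, and None/None is returned during decode when the encoder KV is already cached. The wrapper class `SketchCrossAttention`, its parameter values, and the import path are assumptions for illustration; real use inside vLLM also requires the tensor-parallel groups to be initialized first.

# Hypothetical sketch: wiring QKVCrossParallelLinear into a cross-attention
# layer. Assumes vLLM's distributed/tensor-parallel state is already
# initialized and that the class is importable from the module this diff
# extends (vllm/model_executor/layers/linear.py).
from typing import Optional

import torch

from vllm.model_executor.layers.linear import QKVCrossParallelLinear


class SketchCrossAttention(torch.nn.Module):

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int):
        super().__init__()
        # One module handles both projections: Q from decoder hidden states,
        # K/V from encoder hidden states.
        self.qkv_proj = QKVCrossParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=num_heads,
            total_num_kv_heads=num_kv_heads,
            bias=False)

    def forward(self, decoder_hidden_states: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor]):
        # Prefill: encoder states are provided, so K/V are computed here.
        # Decode: encoder states are None, K/V come back as None and the
        # attention backend reads them from the KV cache instead.
        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)
        return q, k, v

Because the two sub-projections live in a plain dict rather than as registered submodules, their parameters are loaded through the placeholder `weight`/`bias` parameters and the custom weight loaders above, which keeps the module's external interface identical to `QKVParallelLinear`.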