Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj" (#34683)
This commit is contained in:
@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
self,
|
self,
|
||||||
param: Parameter,
|
param: Parameter,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
loaded_shard_id: int | None = None,
|
||||||
):
|
):
|
||||||
if isinstance(loaded_shard_id, tuple):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Shard id with multiple indices is not supported in weight_loader, "
|
|
||||||
"please use weight_loader_v2 instead."
|
|
||||||
)
|
|
||||||
# Special case for GGUF
|
# Special case for GGUF
|
||||||
# initialize GGUF param after we know the quantize type
|
# initialize GGUF param after we know the quantize type
|
||||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||||
@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
def _load_fused_module_from_checkpoint(
|
def _load_fused_module_from_checkpoint(
|
||||||
self,
|
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
|
||||||
param: BasevLLMParameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
output_sizes: list[int] | None = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handle special case for models where MLP layers are already
|
Handle special case for models where MLP layers are already
|
||||||
@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
|
|
||||||
current_shard_offset = 0
|
current_shard_offset = 0
|
||||||
shard_offsets: list[tuple[int, int, int]] = []
|
shard_offsets: list[tuple[int, int, int]] = []
|
||||||
output_sizes = output_sizes or self.output_sizes
|
for i, output_size in enumerate(self.output_sizes):
|
||||||
for i, output_size in enumerate(output_sizes):
|
|
||||||
shard_offsets.append((i, current_shard_offset, output_size))
|
shard_offsets.append((i, current_shard_offset, output_size))
|
||||||
current_shard_offset += output_size
|
current_shard_offset += output_size
|
||||||
|
|
||||||
@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
self,
|
self,
|
||||||
param: BasevLLMParameter,
|
param: BasevLLMParameter,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
loaded_shard_id: int | None = None,
|
||||||
):
|
):
|
||||||
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
if loaded_shard_id is None:
|
||||||
if isinstance(param, PerTensorScaleParameter):
|
if isinstance(param, PerTensorScaleParameter):
|
||||||
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
||||||
return
|
return
|
||||||
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
||||||
param.load_merged_column_weight(loaded_weight=loaded_weight)
|
param.load_merged_column_weight(loaded_weight=loaded_weight)
|
||||||
return
|
return
|
||||||
output_sizes = (
|
|
||||||
[self.output_sizes[idx] for idx in loaded_shard_id]
|
|
||||||
if loaded_shard_id
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if isinstance(param, BlockQuantScaleParameter):
|
|
||||||
weight_block_size = getattr(self, "weight_block_size", None)
|
|
||||||
output_sizes = [
|
|
||||||
adjust_block_scale_shard(weight_block_size, size, 0)[0]
|
|
||||||
for size in (output_sizes or self.output_sizes)
|
|
||||||
]
|
|
||||||
# TODO: @dsikka - move to parameter.py
|
# TODO: @dsikka - move to parameter.py
|
||||||
self._load_fused_module_from_checkpoint(
|
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||||
param, loaded_weight, output_sizes=output_sizes
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
assert loaded_shard_id < len(self.output_sizes)
|
assert loaded_shard_id < len(self.output_sizes)
|
||||||
|
|||||||
@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
|
|||||||
import torch
|
import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
|
CacheConfig,
|
||||||
|
ModelConfig,
|
||||||
|
SpeculativeConfig,
|
||||||
VllmConfig,
|
VllmConfig,
|
||||||
|
get_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
|
divide,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.layernorm import (
|
from vllm.model_executor.layers.layernorm import (
|
||||||
GemmaRMSNorm as Qwen3_5RMSNorm,
|
GemmaRMSNorm as Qwen3_5RMSNorm,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
|
from vllm.model_executor.layers.layernorm import RMSNormGated
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||||
|
mamba_v2_sharded_weight_loader,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateCopyFunc,
|
MambaStateCopyFunc,
|
||||||
MambaStateCopyFuncCalculator,
|
MambaStateCopyFuncCalculator,
|
||||||
@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
|
sharded_weight_loader,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs.qwen3_5 import (
|
from vllm.transformers_utils.configs.qwen3_5 import (
|
||||||
Qwen3_5Config,
|
Qwen3_5Config,
|
||||||
@@ -80,6 +99,7 @@ from .interfaces import (
|
|||||||
)
|
)
|
||||||
from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||||
from .qwen3_next import (
|
from .qwen3_next import (
|
||||||
|
ChunkGatedDeltaRule,
|
||||||
Qwen3NextAttention,
|
Qwen3NextAttention,
|
||||||
Qwen3NextDecoderLayer,
|
Qwen3NextDecoderLayer,
|
||||||
Qwen3NextGatedDeltaNet,
|
Qwen3NextGatedDeltaNet,
|
||||||
@@ -119,31 +139,154 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
|
||||||
|
model_config: ModelConfig | None = None,
|
||||||
|
cache_config: CacheConfig | None = None,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
speculative_config: SpeculativeConfig | None = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super(Qwen3NextGatedDeltaNet, self).__init__()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_v_heads = config.linear_num_value_heads
|
||||||
|
self.num_k_heads = config.linear_num_key_heads
|
||||||
|
self.head_k_dim = config.linear_key_head_dim
|
||||||
|
self.head_v_dim = config.linear_value_head_dim
|
||||||
|
self.key_dim = self.head_k_dim * self.num_k_heads
|
||||||
|
self.value_dim = self.head_v_dim * self.num_v_heads
|
||||||
|
|
||||||
|
self.conv_kernel_size = config.linear_conv_kernel_dim
|
||||||
|
self.layer_idx = extract_layer_index(prefix)
|
||||||
|
self.activation = config.hidden_act
|
||||||
|
self.act = ACT2FN[config.hidden_act]
|
||||||
|
self.layer_norm_epsilon = config.rms_norm_eps
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.speculative_config = speculative_config
|
||||||
|
self.num_spec = (
|
||||||
|
self.speculative_config.num_speculative_tokens
|
||||||
|
if self.speculative_config
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# QKV
|
||||||
|
self.conv_dim = self.key_dim * 2 + self.value_dim
|
||||||
|
self.conv1d = ColumnParallelLinear(
|
||||||
|
input_size=self.conv_kernel_size,
|
||||||
|
output_size=self.conv_dim,
|
||||||
|
bias=False,
|
||||||
|
prefix=f"{prefix}.conv1d",
|
||||||
|
)
|
||||||
|
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||||
|
|
||||||
|
self.in_proj_qkv = MergedColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
output_sizes=[self.key_dim, self.key_dim, self.value_dim],
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.in_proj_qkv",
|
||||||
|
)
|
||||||
|
self.in_proj_z = ColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
output_size=self.value_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.in_proj_z",
|
||||||
|
)
|
||||||
|
self.in_proj_b = ColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
output_size=self.num_v_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.in_proj_b",
|
||||||
|
)
|
||||||
|
self.in_proj_a = ColumnParallelLinear(
|
||||||
|
input_size=self.hidden_size,
|
||||||
|
output_size=self.num_v_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.in_proj_a",
|
||||||
|
)
|
||||||
|
|
||||||
|
query_key_settings = (self.key_dim, 0, False)
|
||||||
|
value_settings = (self.value_dim, 0, False)
|
||||||
|
|
||||||
|
delattr(self.conv1d.weight, "weight_loader")
|
||||||
|
set_weight_attrs(
|
||||||
|
self.conv1d.weight,
|
||||||
|
{
|
||||||
|
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||||
|
[
|
||||||
|
query_key_settings,
|
||||||
|
query_key_settings,
|
||||||
|
value_settings,
|
||||||
|
],
|
||||||
|
self.tp_size,
|
||||||
|
self.tp_rank,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# selective projection used to make dt, B and C input dependant
|
||||||
|
|
||||||
|
# time step projection (discretization)
|
||||||
|
# instantiate once and copy inv_dt in init_weights of PretrainedModel
|
||||||
|
self.dt_bias = nn.Parameter(
|
||||||
|
torch.ones(self.num_v_heads // self.tp_size),
|
||||||
|
)
|
||||||
|
self.A_log = nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
divide(self.num_v_heads, self.tp_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
|
||||||
|
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
|
||||||
|
|
||||||
|
self.norm = RMSNormGated(
|
||||||
|
self.head_v_dim,
|
||||||
|
eps=self.layer_norm_epsilon,
|
||||||
|
group_size=None,
|
||||||
|
norm_before_gate=True,
|
||||||
|
device=current_platform.current_device(),
|
||||||
|
dtype=config.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
self.value_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.out_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
||||||
|
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
|
||||||
def fix_query_key_value_ordering(
|
def fix_query_key_value_ordering(
|
||||||
self,
|
self,
|
||||||
mixed_qkvz: torch.Tensor,
|
mixed_qkv,
|
||||||
mixed_ba: torch.Tensor,
|
z,
|
||||||
|
b,
|
||||||
|
a,
|
||||||
):
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Qwen3.5 Series dont need to fix query key value ordering"
|
"Qwen3.5 Series dont need to fix query key value ordering"
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_qkvz_proj(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
key_dim: int,
|
|
||||||
value_dim: int,
|
|
||||||
quant_config: QuantizationConfig | None,
|
|
||||||
prefix: str,
|
|
||||||
) -> MergedColumnParallelLinear:
|
|
||||||
return MergedColumnParallelLinear(
|
|
||||||
input_size=hidden_size,
|
|
||||||
output_sizes=[key_dim, key_dim, value_dim, value_dim],
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
|||||||
# ============================================================
|
# ============================================================
|
||||||
# Part 1: Input Projection
|
# Part 1: Input Projection
|
||||||
# ============================================================
|
# ============================================================
|
||||||
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
mixed_qkv, _ = self.in_proj_qkv(hidden_states)
|
||||||
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
|
z, _ = self.in_proj_z(hidden_states)
|
||||||
z_size = self.value_dim // self.tp_size
|
|
||||||
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
|
|
||||||
z = z.reshape(z.size(0), -1, self.head_v_dim)
|
z = z.reshape(z.size(0), -1, self.head_v_dim)
|
||||||
ba, _ = self.in_proj_ba(hidden_states)
|
b, _ = self.in_proj_b(hidden_states)
|
||||||
b, a = ba.chunk(2, dim=-1)
|
a, _ = self.in_proj_a(hidden_states)
|
||||||
|
|
||||||
b = b.contiguous()
|
b = b.contiguous()
|
||||||
a = a.contiguous()
|
a = a.contiguous()
|
||||||
@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
|
|||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
# self attention
|
|
||||||
("qkv_proj", "q_proj", "q"),
|
("qkv_proj", "q_proj", "q"),
|
||||||
("qkv_proj", "k_proj", "k"),
|
("qkv_proj", "k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
# mlp
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
# GDN
|
|
||||||
("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
|
|
||||||
("in_proj_qkvz", "in_proj_z", 3),
|
|
||||||
("in_proj_ba", "in_proj_b", 0),
|
|
||||||
("in_proj_ba", "in_proj_a", 1),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
|
|||||||
from vllm.model_executor.layers.layernorm import RMSNormGated
|
from vllm.model_executor.layers.layernorm import RMSNormGated
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||||
|
|
||||||
# projection of the input hidden states
|
# projection of the input hidden states
|
||||||
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
|
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
|
||||||
# we need to create qkvz_proj adaptively here.
|
self.projection_size_ba = self.num_v_heads * 2
|
||||||
self.in_proj_qkvz = self.create_qkvz_proj(
|
self.in_proj_qkvz = ColumnParallelLinear(
|
||||||
hidden_size=self.hidden_size,
|
input_size=self.hidden_size,
|
||||||
key_dim=self.key_dim,
|
output_size=self.projection_size_qkvz,
|
||||||
value_dim=self.value_dim,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.in_proj_qkvz",
|
prefix=f"{prefix}.in_proj_qkvz",
|
||||||
)
|
)
|
||||||
# ba_proj doesn't support blockwise fp8 quantization.
|
# ba_proj doesn't support blockwise fp8 quantization.
|
||||||
self.in_proj_ba = MergedColumnParallelLinear(
|
self.in_proj_ba = ColumnParallelLinear(
|
||||||
input_size=self.hidden_size,
|
input_size=self.hidden_size,
|
||||||
output_sizes=[self.num_v_heads] * 2,
|
output_size=self.projection_size_ba,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.in_proj_ba",
|
prefix=f"{prefix}.in_proj_ba",
|
||||||
@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
compilation_config.static_forward_context[prefix] = self
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
|
||||||
def create_qkvz_proj(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
key_dim: int,
|
|
||||||
value_dim: int,
|
|
||||||
quant_config: QuantizationConfig | None,
|
|
||||||
prefix: str,
|
|
||||||
) -> MergedColumnParallelLinear:
|
|
||||||
return MergedColumnParallelLinear(
|
|
||||||
input_size=hidden_size,
|
|
||||||
output_sizes=[sum((key_dim, key_dim, value_dim)), value_dim],
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.in_proj_qkvz",
|
|
||||||
)
|
|
||||||
|
|
||||||
def fix_query_key_value_ordering(
|
def fix_query_key_value_ordering(
|
||||||
self,
|
self,
|
||||||
mixed_qkvz: torch.Tensor,
|
mixed_qkvz,
|
||||||
mixed_ba: torch.Tensor,
|
mixed_ba,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
|
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
|
||||||
|
|||||||
Reference in New Issue
Block a user