Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -32,6 +32,7 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
"""Inference-only Flash model compatible with HuggingFace weights."""
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Optional, Union
|
||||
@@ -47,29 +48,37 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
block_dequant)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import block_dequant
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
from .utils import (
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashConfig(PretrainedConfig):
|
||||
"""Flash model configuration."""
|
||||
|
||||
model_type = "longcat_flash"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
@@ -132,8 +141,9 @@ class FlashConfig(PretrainedConfig):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = (num_hidden_layers if num_hidden_layers
|
||||
is not None else num_layers)
|
||||
self.num_hidden_layers = (
|
||||
num_hidden_layers if num_hidden_layers is not None else num_layers
|
||||
)
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.ep_size = ep_size
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
@@ -162,8 +172,11 @@ class FlashConfig(PretrainedConfig):
|
||||
self.zero_expert_type = zero_expert_type
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.hidden_act = "silu"
|
||||
self.intermediate_size = self.ffn_hidden_size if hasattr(
|
||||
self, "ffn_hidden_size") else self.intermediate_size
|
||||
self.intermediate_size = (
|
||||
self.ffn_hidden_size
|
||||
if hasattr(self, "ffn_hidden_size")
|
||||
else self.intermediate_size
|
||||
)
|
||||
if hasattr(self, "moe_intermediate_size"):
|
||||
self.moe_intermediate_size = self.moe_intermediate_size
|
||||
elif hasattr(self, "expert_ffn_hidden_size"):
|
||||
@@ -201,8 +214,9 @@ class FlashMLP(nn.Module):
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -216,15 +230,19 @@ class FlashMLP(nn.Module):
|
||||
|
||||
|
||||
class LongcatRouter(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
zero_expert_num=0,
|
||||
rounter_params_dtype=torch.bfloat16,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
zero_expert_num=0,
|
||||
rounter_params_dtype=torch.bfloat16,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.n_routed_experts = config.n_routed_experts if hasattr(
|
||||
config, "n_routed_experts") else config.num_experts[0]
|
||||
self.n_routed_experts = (
|
||||
config.n_routed_experts
|
||||
if hasattr(config, "n_routed_experts")
|
||||
else config.num_experts[0]
|
||||
)
|
||||
self.n_routed_experts = self.n_routed_experts + zero_expert_num
|
||||
self.classifier = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
@@ -235,7 +253,8 @@ class LongcatRouter(nn.Module):
|
||||
prefix=f"{prefix}.classifier",
|
||||
)
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype))
|
||||
torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits, _ = self.classifier(hidden_states)
|
||||
@@ -243,7 +262,6 @@ class LongcatRouter(nn.Module):
|
||||
|
||||
|
||||
class LongcatMoe(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlashConfig,
|
||||
@@ -271,7 +289,8 @@ class LongcatMoe(nn.Module):
|
||||
config=config,
|
||||
zero_expert_num=self.zero_expert_num,
|
||||
rounter_params_dtype=self.rounter_params_dtype,
|
||||
prefix=f"{prefix}.gate")
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=num_experts,
|
||||
@@ -291,14 +310,13 @@ class LongcatMoe(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
router_logits = self.router(hidden_states.to(
|
||||
self.rounter_params_dtype))
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
router_logits = self.router(hidden_states.to(self.rounter_params_dtype))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
@@ -316,67 +334,76 @@ class FlashDecoderLayer(nn.Module):
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = int(prefix.split(sep=".")[-1])
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
|
||||
# Dual attention structure
|
||||
self.self_attn = nn.ModuleList([
|
||||
DeepseekV2MLAAttention(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=(config.q_lora_rank if hasattr(
|
||||
config, "q_lora_rank") else None),
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=None if "self_attn" in getattr(
|
||||
config, "disable_quant_module", []) else quant_config,
|
||||
prefix=f"{prefix}.self_attn.{i}",
|
||||
) for i in range(2)
|
||||
])
|
||||
self.input_layernorm = nn.ModuleList([
|
||||
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
for i in range(2)
|
||||
])
|
||||
self.post_attention_layernorm = nn.ModuleList([
|
||||
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
for i in range(2)
|
||||
])
|
||||
self.self_attn = nn.ModuleList(
|
||||
[
|
||||
DeepseekV2MLAAttention(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=(
|
||||
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
|
||||
),
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=None
|
||||
if "self_attn" in getattr(config, "disable_quant_module", [])
|
||||
else quant_config,
|
||||
prefix=f"{prefix}.self_attn.{i}",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
)
|
||||
self.input_layernorm = nn.ModuleList(
|
||||
[RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)]
|
||||
)
|
||||
self.post_attention_layernorm = nn.ModuleList(
|
||||
[RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)]
|
||||
)
|
||||
|
||||
# Dual MLP structure
|
||||
self.mlps = nn.ModuleList([
|
||||
FlashMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=None if "mlps" in getattr(
|
||||
config, "disable_quant_module", []) else quant_config,
|
||||
prefix=f"{prefix}.mlps.{i}",
|
||||
) for i in range(2)
|
||||
])
|
||||
self.mlps = nn.ModuleList(
|
||||
[
|
||||
FlashMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=None
|
||||
if "mlps" in getattr(config, "disable_quant_module", [])
|
||||
else quant_config,
|
||||
prefix=f"{prefix}.mlps.{i}",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
self.mlp = LongcatMoe(
|
||||
config=config,
|
||||
num_experts=config.n_routed_experts if hasattr(
|
||||
config, "n_routed_experts") else
|
||||
config.num_experts[self.layer_idx],
|
||||
num_experts=config.n_routed_experts
|
||||
if hasattr(config, "n_routed_experts")
|
||||
else config.num_experts[self.layer_idx],
|
||||
top_k=config.moe_topk
|
||||
if hasattr(config, "moe_topk") else config.num_experts_per_tok,
|
||||
if hasattr(config, "moe_topk")
|
||||
else config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
quant_config=quant_config,
|
||||
@@ -389,13 +416,11 @@ class FlashDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm[0](hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm[0](hidden_states,
|
||||
residual)
|
||||
hidden_states, residual = self.input_layernorm[0](hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn[0](
|
||||
positions=positions,
|
||||
@@ -403,7 +428,8 @@ class FlashDecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm[0](
|
||||
hidden_states, residual)
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
# moe
|
||||
hidden_states_copy = hidden_states.clone()
|
||||
@@ -412,8 +438,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
# first mlp
|
||||
hidden_states = self.mlps[0](hidden_states)
|
||||
|
||||
hidden_states, residual = self.input_layernorm[1](hidden_states,
|
||||
residual)
|
||||
hidden_states, residual = self.input_layernorm[1](hidden_states, residual)
|
||||
|
||||
# second_attn
|
||||
hidden_states = self.self_attn[1](
|
||||
@@ -421,7 +446,8 @@ class FlashDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
hidden_states, residual = self.post_attention_layernorm[1](
|
||||
hidden_states, residual)
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
# second_mlp
|
||||
hidden_states = self.mlps[1](hidden_states)
|
||||
@@ -462,14 +488,15 @@ class FlashModel(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -501,10 +528,9 @@ class FlashModel(nn.Module):
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -532,26 +558,32 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
config.intermediate_size = config.ffn_hidden_size if hasattr(
|
||||
config, "ffn_hidden_size") else config.intermediate_size
|
||||
config.intermediate_size = (
|
||||
config.ffn_hidden_size
|
||||
if hasattr(config, "ffn_hidden_size")
|
||||
else config.intermediate_size
|
||||
)
|
||||
self.lora_config = lora_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = FlashModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = FlashModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
@@ -563,8 +595,9 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
@@ -581,14 +614,12 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts if hasattr(
|
||||
self.config, "n_routed_experts") else
|
||||
self.config.num_experts[0],
|
||||
num_experts=self.config.n_routed_experts
|
||||
if hasattr(self.config, "n_routed_experts")
|
||||
else self.config.num_experts[0],
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
@@ -610,8 +641,9 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if (name.endswith(".bias")
|
||||
or name.endswith("_bias")) and name not in params_dict:
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
# Skip mtp
|
||||
if ".mtp." in name:
|
||||
@@ -633,22 +665,25 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# Skip mtp
|
||||
if ".mtp." in name_mapped:
|
||||
continue
|
||||
if (name_mapped.endswith(".bias")
|
||||
or name_mapped.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
if (
|
||||
name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name_mapped]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader = typing.cast(Callable[..., bool],
|
||||
param.weight_loader)
|
||||
success = weight_loader(param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True)
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
@@ -672,8 +707,9 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
@@ -681,35 +717,35 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if isinstance(self.model.layers[layer_id], PPMissingLayer):
|
||||
continue
|
||||
self_attn = self.model.layers[layer_id].self_attn[i]
|
||||
if hasattr(self.quant_config, "weight_block_size"
|
||||
) and self_attn.kv_b_proj.weight.dtype in (
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
):
|
||||
if hasattr(
|
||||
self.quant_config, "weight_block_size"
|
||||
) and self_attn.kv_b_proj.weight.dtype in (
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
if weight_block_size is not None:
|
||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||
dtype = torch.get_default_dtype()
|
||||
w = block_dequant(self_attn.kv_b_proj.weight,
|
||||
self_attn.kv_b_proj.weight_scale_inv,
|
||||
weight_block_size).to(dtype)
|
||||
w = block_dequant(
|
||||
self_attn.kv_b_proj.weight,
|
||||
self_attn.kv_b_proj.weight_scale_inv,
|
||||
weight_block_size,
|
||||
).to(dtype)
|
||||
else:
|
||||
w = self_attn.kv_b_proj.weight
|
||||
|
||||
w_kc, w_vc = w.unflatten(
|
||||
0,
|
||||
(-1,
|
||||
self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split(
|
||||
[self_attn.qk_nope_head_dim, self_attn.v_head_dim],
|
||||
dim=1)
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(
|
||||
1, 2)
|
||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
||||
if self.config.mla_scale_q_lora:
|
||||
self_attn.q_a_layernorm.weight.data *= (
|
||||
self.config.hidden_size / self.config.q_lora_rank)**0.5
|
||||
self.config.hidden_size / self.config.q_lora_rank
|
||||
) ** 0.5
|
||||
if self.config.mla_scale_kv_lora:
|
||||
self_attn.kv_a_layernorm.weight.data *= (
|
||||
self.config.hidden_size /
|
||||
self.config.kv_lora_rank)**0.5
|
||||
self.config.hidden_size / self.config.kv_lora_rank
|
||||
) ** 0.5
|
||||
return loaded_params
|
||||
|
||||
Reference in New Issue
Block a user