diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index b84b4e2ae..4050bf045 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -44,6 +44,9 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors import (
+    compressed_tensors as ct,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
@@ -828,38 +831,74 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
         name: str,
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
-        modules = name.split(".")
-        # Permute Q/K weights and corresponding scales for rotary embedding.
-        # This pathway is validated against modelopt and compressed-tensors ckpts,
-        # and for per-tensor, per-group (e.g. GPTQ), and per-channel quant schemes.
-        # Note: permutations are not feasible only for per-block (e.g. DeepSeek 128x128)
-        # For per-block quantization, consider not quantizing q/k_proj.
-        is_weight = modules[-1] in ("weight", "weight_packed")
-        is_weight_scale = (
-            modules[-1] == "weight_scale"
-            and loaded_weight.numel() > 1  # no need to permute per-tensor scales
-        )
-        is_k_proj = "wk" in modules or "k_proj" in modules
-        is_q_proj = "wq" in modules or "q_proj" in modules
-
-        if (is_weight or is_weight_scale) and (is_k_proj or is_q_proj):
-            original_ndim = loaded_weight.ndim
-            if original_ndim == 1:
-                loaded_weight = loaded_weight.unsqueeze(-1)
-
-            f_out, f_in = loaded_weight.shape
-            n_heads = (
-                self.config.num_key_value_heads
-                if is_k_proj
-                else self.config.num_attention_heads
+        # Helper function to permute the weight's channels
+        def permute(
+            w: torch.Tensor,
+            n_heads: int,
+            is_nvfp4_weight_scale: bool,
+            is_ct_int8_or_fp8_weight_scale: bool,
+        ):
+            # Calculate the expected shape of the weight.
+            # Do not rely on w's shape, as it may be in another layout.
+            attn_in = self.config.head_dim * n_heads
+            attn_out = (
+                self.config.hidden_size
+                if not is_ct_int8_or_fp8_weight_scale
+                else w.shape[-1]
             )
-            loaded_weight = (
-                loaded_weight.view(n_heads, f_out // n_heads // 2, 2, f_in)
+
+            # If the weight is FP4 packed as uint8, we need to divide attn_out
+            # by 2.
+            if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
+                attn_out = attn_out // 2
+
+            # If the weight is a weight scale, we need to divide attn_out by
+            # block size, which is currently 16.
+            elif (
+                w.dtype == torch.float8_e4m3fn
+                and is_nvfp4_weight_scale
+                and w.shape[1] * 16 == attn_out
+            ):
+                attn_out = attn_out // 16
+
+            return (
+                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                 .transpose(1, 2)
-                .reshape(f_out, f_in)
+                .reshape(attn_in, attn_out)
             )
-            if original_ndim == 1:
-                loaded_weight = loaded_weight.squeeze(-1)
+        modules = name.split(".")
+
+        # Permute Q/K weights and weight block scales for rotary embedding
+        is_weight = modules[-1] == "weight"
+        is_nvfp4_weight_scale = (
+            modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
+        )
+        is_ct_int8_or_fp8_weight_scale = False
+        if modules[-1] == "weight_scale" and isinstance(
+            self.model.quant_config, ct.CompressedTensorsConfig
+        ):
+            from compressed_tensors import CompressionFormat
+
+            is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
+                CompressionFormat.int_quantized.value,
+                CompressionFormat.float_quantized.value,
+            ] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
+
+        if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
+            if "wk" in modules or "k_proj" in modules:
+                loaded_weight = permute(
+                    loaded_weight,
+                    self.config.num_key_value_heads,
+                    is_nvfp4_weight_scale,
+                    is_ct_int8_or_fp8_weight_scale,
+                )
+            elif "wq" in modules or "q_proj" in modules:
+                loaded_weight = permute(
+                    loaded_weight,
+                    self.config.num_attention_heads,
+                    is_nvfp4_weight_scale,
+                    is_ct_int8_or_fp8_weight_scale,
+                )
 
         return name, loaded_weight
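Note for reviewers: the standalone sketch below (not part of the patch) reproduces the view/transpose/reshape idiom used by the new permute() helper on a toy tensor; the sizes and variable names are illustrative assumptions, not taken from the source. It shows that, within each head, the transform regroups interleaved row pairs so that even-indexed rows come first and odd-indexed rows second, which is the layout the surrounding code prepares for the rotary embedding applied at runtime.

# Minimal sketch of the permute() core transform on a toy weight.
# n_heads, head_dim and hidden are made-up sizes for illustration only.
import torch

n_heads, head_dim, hidden = 2, 4, 8
attn_in = n_heads * head_dim            # output channels of q/k_proj
attn_out = hidden                       # input channels
w = torch.arange(attn_in * attn_out).reshape(attn_in, attn_out)

# Same view/transpose/reshape as in the diff: each head's rows, stored as
# interleaved pairs (0, 1), (2, 3), ..., are reordered to even rows first,
# then odd rows, i.e. [0, 2, 1, 3] for head_dim = 4.
permuted = (
    w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
    .transpose(1, 2)
    .reshape(attn_in, attn_out)
)

print(permuted[:head_dim, 0] // attn_out)   # tensor([0, 2, 1, 3])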