diff --git a/tests/models/multimodal/generation/test_maverick.py b/tests/models/multimodal/generation/test_maverick.py
index 6fc2efa41..ff6e523e5 100644
--- a/tests/models/multimodal/generation/test_maverick.py
+++ b/tests/models/multimodal/generation/test_maverick.py
@@ -305,10 +305,10 @@ def create_text_model_weights(text_config: dict[str, Any]) -> dict[str, torch.Te
         # Self-attention weights (separate q, k, v projections)
         weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn(
-            hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16
+            num_attention_heads * head_dim, hidden_size, dtype=torch.bfloat16
         )
         weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn(
-            hidden_size, num_key_value_heads * head_dim, dtype=torch.bfloat16
+            num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16
         )
         weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn(
             num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16
         )
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 4050bf045..b84b4e2ae 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -44,9 +44,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.compressed_tensors import (
-    compressed_tensors as ct,
-)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
@@ -831,74 +828,38 @@ class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
         name: str,
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
-        # Helper function to permute the weight's channels
-        def permute(
-            w: torch.Tensor,
-            n_heads: int,
-            is_nvfp4_weight_scale: bool,
-            is_ct_int8_or_fp8_weight_scale: bool,
-        ):
-            # Calculate the expected shape of the weight.
-            # Do not rely on w's shape, as it may be in another layout.
-            attn_in = self.config.head_dim * n_heads
-            attn_out = (
-                self.config.hidden_size
-                if not is_ct_int8_or_fp8_weight_scale
-                else w.shape[-1]
-            )
-
-            # If the weight is FP4 packed as uint8, we need to divide attn_out
-            # by 2.
-            if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
-                attn_out = attn_out // 2
-
-            # If the weight is a weight scale, we need to divide attn_out by
-            # block size, which is currently 16.
-            elif (
-                w.dtype == torch.float8_e4m3fn
-                and is_nvfp4_weight_scale
-                and w.shape[1] * 16 == attn_out
-            ):
-                attn_out = attn_out // 16
-
-            return (
-                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
-                .transpose(1, 2)
-                .reshape(attn_in, attn_out)
-            )
-
         modules = name.split(".")
-
-        # Permute Q/K weights and weight block scales for rotary embedding
-        is_weight = modules[-1] == "weight"
-        is_nvfp4_weight_scale = (
-            modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
+        # Permute Q/K weights and corresponding scales for rotary embedding.
+        # This pathway is validated against modelopt and compressed-tensors ckpts,
+        # and for per-tensor, per-group (e.g. GPTQ), and per-channel quant schemes.
+        # Note: permutations are not feasible only for per-block (e.g. DeepSeek 128x128).
+        # For per-block quantization, consider not quantizing q/k_proj.
+        is_weight = modules[-1] in ("weight", "weight_packed")
+        is_weight_scale = (
+            modules[-1] == "weight_scale"
+            and loaded_weight.numel() > 1  # no need to permute per-tensor scales
         )
-        is_ct_int8_or_fp8_weight_scale = False
-        if modules[-1] == "weight_scale" and isinstance(
-            self.model.quant_config, ct.CompressedTensorsConfig
-        ):
-            from compressed_tensors import CompressionFormat
+        is_k_proj = "wk" in modules or "k_proj" in modules
+        is_q_proj = "wq" in modules or "q_proj" in modules
 
-            is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
-                CompressionFormat.int_quantized.value,
-                CompressionFormat.float_quantized.value,
-            ] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
+        if (is_weight or is_weight_scale) and (is_k_proj or is_q_proj):
+            original_ndim = loaded_weight.ndim
+            if original_ndim == 1:
+                loaded_weight = loaded_weight.unsqueeze(-1)
 
-        if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
-            if "wk" in modules or "k_proj" in modules:
-                loaded_weight = permute(
-                    loaded_weight,
-                    self.config.num_key_value_heads,
-                    is_nvfp4_weight_scale,
-                    is_ct_int8_or_fp8_weight_scale,
-                )
-            elif "wq" in modules or "q_proj" in modules:
-                loaded_weight = permute(
-                    loaded_weight,
-                    self.config.num_attention_heads,
-                    is_nvfp4_weight_scale,
-                    is_ct_int8_or_fp8_weight_scale,
-                )
+            f_out, f_in = loaded_weight.shape
+            n_heads = (
+                self.config.num_key_value_heads
+                if is_k_proj
+                else self.config.num_attention_heads
+            )
+            loaded_weight = (
+                loaded_weight.view(n_heads, f_out // n_heads // 2, 2, f_in)
+                .transpose(1, 2)
+                .reshape(f_out, f_in)
+            )
+
+            if original_ndim == 1:
+                loaded_weight = loaded_weight.squeeze(-1)
 
         return name, loaded_weight
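
Below is a minimal, standalone sketch (not part of the patch; the sizes and tensor names are made up) of what the view/transpose/reshape in the new code does to a q_proj/k_proj weight: within each head it gathers even-indexed output rows into the first half and odd-indexed rows into the second half. The diff applies the same expression to 2-D weight scales, and unsqueezes 1-D per-channel scales to 2-D first.

import torch

# Hypothetical sizes for illustration only.
n_heads, head_dim, hidden_size = 4, 8, 16
f_out, f_in = n_heads * head_dim, hidden_size
w = torch.randn(f_out, f_in)

# Same expression as the new permute_qk_weight_for_rotary path: group each
# head's rows into (head_dim // 2, 2) pairs, swap the pair axis outward,
# and flatten back.
permuted = (
    w.view(n_heads, f_out // n_heads // 2, 2, f_in)
    .transpose(1, 2)
    .reshape(f_out, f_in)
)

# Equivalent per-head index permutation: even rows first, then odd rows.
idx = torch.arange(head_dim).view(head_dim // 2, 2).t().reshape(-1)
expected = w.view(n_heads, head_dim, f_in)[:, idx, :].reshape(f_out, f_in)
assert torch.equal(permuted, expected)

# A 1-D per-channel scale permutes the same way once unsqueezed to 2-D,
# mirroring the original_ndim == 1 handling in the diff.
scale = torch.rand(f_out)
scale_permuted = (
    scale.unsqueeze(-1)
    .view(n_heads, f_out // n_heads // 2, 2, 1)
    .transpose(1, 2)
    .reshape(f_out, 1)
    .squeeze(-1)
)
assert torch.equal(scale_permuted, scale.view(n_heads, head_dim)[:, idx].reshape(-1))

A per-tensor scale (numel() == 1) is a scalar and is unaffected by any row reordering, which is why the new is_weight_scale check skips it.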