diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f4038801c..9b9a0858a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -964,12 +964,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) elif self.use_marlin: - prepare_moe_fp8_layer_for_marlin( - layer, False, input_dtype=self.marlin_input_dtype + ( + workspace, + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + ) = prepare_moe_fp8_layer_for_marlin( + layer, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + input_dtype=self.marlin_input_dtype, ) - # Activations not quantized for marlin. - del layer.w13_input_scale - del layer.w2_input_scale + layer.workspace = workspace + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w2_weight", w2_weight) + replace_parameter(layer, "w13_weight_scale", w13_weight_scale) + replace_parameter(layer, "w2_weight_scale", w2_weight_scale) if self.use_cutlass: assert self.weight_quant.strategy != QuantizationStrategy.BLOCK diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3d571a77d..e77b94db8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( w13_weight, w2_weight ) + elif self.fp8_backend == Fp8MoeBackend.MARLIN: + ( + workspace, + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + ) = prepare_moe_fp8_layer_for_marlin( + layer, + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + input_dtype=self.marlin_input_dtype, + ) + layer.workspace = workspace + elif self.fp8_backend in [ Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_TRTLLM, @@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale) - # TODO(rob): we do this after replace_parameter() because - # prepare_moe_fp8_layer_for_marlin uses on the layer's params - # directly. We will refactor this in a follow up PR. - if self.fp8_backend == Fp8MoeBackend.MARLIN: - prepare_moe_fp8_layer_for_marlin( - layer, False, input_dtype=self.marlin_input_dtype - ) - # Activations not quantized for marlin. - del layer.w13_input_scale - del layer.w2_input_scale - def _setup_kernel(self, layer: Module) -> None: """Setup Modular Kernel for TP Case""" # NOTE(rob): this is a WIP refactor. We are first migrating @@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) -> FusedMoEQuantConfig | None: if self.fp8_backend == Fp8MoeBackend.MARLIN: return fp8_w8a16_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, + w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), + w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), block_shape=self.weight_block_size, ) return fp8_w8a8_moe_quant_config( - w1_scale=( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ), - w2_scale=( - layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale - ), + w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), + w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.weight_block_size, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 819704803..15d37f7d3 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -315,10 +315,26 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) elif self.use_marlin: - prepare_moe_fp8_layer_for_marlin(layer, False) - # Activations not quantized for marlin. - del layer.w13_input_scale - del layer.w2_input_scale + (workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = ( + prepare_moe_fp8_layer_for_marlin( + layer, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + ) + ) + layer.workspace = workspace + # TODO(rob): once we apply refactor to Quark, switch to using + # replace_parameter for compatibility with reloading in RL. + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 4d2f2fd71..a8d9db224 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -199,9 +199,18 @@ def prepare_fp8_layer_for_marlin( def prepare_moe_fp8_layer_for_marlin( layer: torch.nn.Module, - size_k_first: bool = True, + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, + w13_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, input_dtype: torch.dtype | None = None, -) -> None: +) -> tuple[ + torch.Tensor, # workspace + torch.Tensor, # w13_weight + torch.Tensor, # w2_weight + torch.Tensor, # w13_weight_scale + torch.Tensor, # w2_weight_scale +]: logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " @@ -209,7 +218,7 @@ def prepare_moe_fp8_layer_for_marlin( "performance for compute-heavy workloads." ) if input_dtype is not None and input_dtype.itemsize == 1: - raise RuntimeError("Marlin W8A8 is not supported.") + raise NotImplementedError("Marlin W8A8 is not supported.") e = layer.num_experts k = layer.hidden_size @@ -218,53 +227,40 @@ def prepare_moe_fp8_layer_for_marlin( # WORKSPACE device = layer.w13_weight.device - layer.workspace = marlin_make_workspace_new(device, 4) + workspace = marlin_make_workspace_new(device, 4) perm = torch.empty(0, dtype=torch.int, device=device) # WEIGHT # Repack weights to marlin format - for name in ["w13_weight", "w2_weight"]: - weight = getattr(layer, name) + def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor: tensor_list = [] if "w13" in name: size_n, size_k = n * 2, k else: size_n, size_k = k, n - if size_k_first: - assert weight.shape == (e, size_k, size_n) - else: - assert weight.shape == (e, size_n, size_k) + assert weight.shape == (e, size_n, size_k) for i in range(e): - qweight = pack_fp8_to_int32(weight[i], size_k_first) - if not size_k_first: - qweight = qweight.T.contiguous() + qweight = pack_fp8_to_int32(weight[i], size_k_first=False) + qweight = qweight.T.contiguous() marlin_qweight = ops.gptq_marlin_repack( b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8 ) tensor_list.append(marlin_qweight) - weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - weight = torch.nn.Parameter(weight, requires_grad=False) + return torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - setattr(layer, name, weight) + w13_weight = repack_weight("w13", w13_weight) + w2_weight = repack_weight("w2", w2_weight) # WEIGHT SCALES # Permute scales group_size = -1 if weight_block_size is None else weight_block_size[1] - for name in ["w13", "w2"]: - if name + "_weight_scale" in dir(layer): - new_name = name + "_weight_scale" - scales = getattr(layer, new_name).to(layer.orig_dtype) - delattr(layer, new_name) - elif name + "_weight_scale_inv" in dir(layer): - new_name = name + "_weight_scale_inv" - scales = getattr(layer, new_name).to(layer.orig_dtype) - delattr(layer, new_name) - + def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor: + scales = scales.to(layer.orig_dtype) tensor_list = [] if "w13" in name: size_n, size_k = n * 2, k @@ -294,8 +290,7 @@ def prepare_moe_fp8_layer_for_marlin( # block-wise quantization -> group-wise quantization # (e, size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (e, size_k // block_size[1], size_n) - if not size_k_first: - scales = scales.permute(0, 2, 1) + scales = scales.permute(0, 2, 1) block_n = weight_block_size[0] scales = scales.repeat_interleave(block_n, 2) # size_n may not divisible by block_size[0] @@ -310,26 +305,18 @@ def prepare_moe_fp8_layer_for_marlin( scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) if input_dtype != torch.float8_e4m3fn: scales = fp8_fused_exponent_bias_into_scales(scales) - scales = torch.nn.Parameter(scales, requires_grad=False) + return scales - setattr(layer, name + "_weight_scale", scales) + w13_weight_scale = permute_scales(w13_weight_scale, "w13") + w2_weight_scale = permute_scales(w2_weight_scale, "w2") - # BIAS - # Permute bias - for name in ["w13_bias", "w2_bias"]: - if not hasattr(layer, name): - continue - bias = getattr(layer, name).to(layer.orig_dtype) - - tensor_list = [] - for i in range(e): - expert_bias = bias[i] - - tensor_list.append(marlin_permute_bias(expert_bias)) - - bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - bias = torch.nn.Parameter(bias, requires_grad=False) - setattr(layer, name, bias) + return ( + workspace, + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + ) def pack_fp8_to_int32(