[Kernel][Quantization] Integrate block-quantized CUTLASS kernels for DeepSeekV3 (#12587)

Integrates the block-quantized kernels introduced in
https://github.com/vllm-project/vllm/pull/11868 for use in linear
layers.
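
For context, "block-quantized" here means the DeepSeekV3-style FP8 scheme in which each weight tile (typically (128, 128)) carries its own scale, rather than one scale per tensor or per channel. A minimal reference sketch of what dequantizing such a weight means (illustrative only; the CUTLASS kernels fuse the scaling into the GEMM and never materialize the dequantized weight):

    import torch

    def block_dequantize(w_fp8: torch.Tensor,
                         w_scale: torch.Tensor,
                         block: int = 128) -> torch.Tensor:
        """Expand per-block scales elementwise and dequantize the weight.

        w_fp8:   (N, K) FP8 weight tensor.
        w_scale: (ceil(N / block), ceil(K / block)) per-block scales.
        """
        n, k = w_fp8.shape
        # Broadcast each block's scale over its (block, block) tile,
        # trimming any partial blocks at the edges.
        scale = w_scale.repeat_interleave(block, dim=0)[:n]
        scale = scale.repeat_interleave(block, dim=1)[:, :k]
        return w_fp8.to(torch.float32) * scale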

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: Tyler Michael Smith
Date:   2025-01-31 18:29:11 -05:00
Committed by: GitHub
Parent: 145c2ff648
Commit: eb5741ad42

8 changed files with 160 additions and 37 deletions


@@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_block_fp8_supported, cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
@@ -133,6 +134,7 @@ class Fp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
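
The new flag is probed once at construction time, alongside the existing cutlass_fp8_supported() check, so the per-call apply path stays branch-cheap. The real probe lives in w8a8_utils; as a rough sketch of the kind of check involved (assumed logic, since block-scaled FP8 CUTLASS GEMMs target Hopper-class hardware):

    import torch

    def cutlass_block_fp8_supported_sketch() -> bool:
        # Assumed gating logic: block-scaled FP8 CUTLASS kernels need an
        # FP8-capable GPU; SM90 (Hopper) is the usual floor. The actual
        # vLLM check may also consult build flags and env overrides.
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (9, 0)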
@@ -359,6 +361,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
             )
         return apply_fp8_linear(
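
Downstream, the block-FP8 apply path uses the threaded flag to choose between the CUTLASS kernel and the pre-existing fallback. A hypothetical sketch of that dispatch shape (the keyword argument matches the diff; every other name is illustrative, and the fallback reuses the block_dequantize sketch above):

    from typing import Optional

    import torch

    def apply_block_fp8_linear_sketch(
            x: torch.Tensor,
            weight: torch.Tensor,
            weight_scale: torch.Tensor,
            bias: Optional[torch.Tensor] = None,
            cutlass_block_fp8_supported: bool = False) -> torch.Tensor:
        # Illustrative dispatch only; the real apply path also quantizes
        # the activations per token group before calling the kernel.
        if cutlass_block_fp8_supported:
            # Fast path: a CUTLASS block-scaled FP8 GEMM would run here;
            # elided because it is a compiled custom op.
            raise NotImplementedError("CUTLASS path elided in this sketch")
        # Fallback: dequantize per block and run a plain matmul.
        out = x @ block_dequantize(weight, weight_scale).t()
        return out if bias is None else out + bias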