[Kernel] Add nvfp4 gemm flashinfer backends (#22346)
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -24,6 +25,13 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self):
    """Pick the NVFP4 GEMM backend for this scheme and set the group size.

    Backend selection, in priority order:
      1. "flashinfer-trtllm" when ``envs.VLLM_USE_TRTLLM_FP4_GEMM`` is
         truthy; ``has_flashinfer()`` must then be true or the assert
         fails.
      2. "flashinfer-cutlass" when ``has_flashinfer()`` is true.
      3. "cutlass" otherwise.
    """
    if envs.VLLM_USE_TRTLLM_FP4_GEMM:
        # Explicit opt-in: fail here instead of silently falling back
        # to another backend when FlashInfer is unavailable.
        assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
        chosen_backend = "flashinfer-trtllm"
    else:
        chosen_backend = ("flashinfer-cutlass"
                          if has_flashinfer() else "cutlass")

    self.backend = chosen_backend
    self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
@@ -108,16 +116,36 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
layer.weight_global_scale.max().to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
# shuffles ourselves.
|
||||
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
|
||||
|
||||
# required by cutlass kernel; need Parameter, not ModelWeightParameter
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
weight = layer.weight_packed.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
layer.alpha = Parameter(layer.input_global_scale *
|
||||
layer.weight_global_scale,
|
||||
requires_grad=False)
|
||||
epilogue_tile_m = 128
|
||||
weight = shuffle_matrix_a(weight.view(torch.uint8),
|
||||
epilogue_tile_m)
|
||||
weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
|
||||
torch.uint8), epilogue_tile_m).reshape(
|
||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||
requires_grad=False)
|
||||
|
||||
layer.alpha = Parameter(
|
||||
1 / (layer.input_global_scale * layer.weight_global_scale),
|
||||
requires_grad=False)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -128,7 +156,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
out = run_nvfp4_emulations(
|
||||
x=x,
|
||||
input_global_scale=layer.input_global_scale,
|
||||
weight=layer.weight,
|
||||
weight=layer.weight_packed,
|
||||
weight_scale_swizzled=layer.weight_scale_swizzled,
|
||||
weight_global_scale=layer.weight_global_scale)
|
||||
if bias is not None:
|
||||
@@ -136,14 +164,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
return out
|
||||
|
||||
output_dtype = x.dtype
|
||||
output_shape = [x.shape[0], layer.weight.shape[0]]
|
||||
output_shape = [x.shape[0], layer.weight_packed.shape[0]]
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
|
||||
|
||||
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
|
||||
layer.weight_scale_swizzled,
|
||||
1 / layer.alpha, output_dtype)
|
||||
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
|
||||
layer.weight_scale_swizzled, layer.alpha, output_dtype)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
else:
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
@@ -38,7 +38,8 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import next_power_of_2
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
|
||||
has_flashinfer_moe)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -724,16 +725,20 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
self.use_marlin = False
|
||||
|
||||
if not self.cutlass_nvfp4_supported:
|
||||
if is_fp4_marlin_supported():
|
||||
self.use_marlin = True
|
||||
else:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above.")
|
||||
if envs.VLLM_USE_TRTLLM_FP4_GEMM:
|
||||
assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
|
||||
self.backend = "flashinfer-trtllm"
|
||||
elif has_flashinfer():
|
||||
self.backend = "flashinfer-cutlass"
|
||||
elif cutlass_fp4_supported():
|
||||
self.backend = "cutlass"
|
||||
elif is_fp4_marlin_supported():
|
||||
self.backend = "marlin"
|
||||
else:
|
||||
raise ValueError("Current platform does not support NVFP4"
|
||||
" quantization. Please use Blackwell and"
|
||||
" above.")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -815,17 +820,38 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
# block_size = 16;
|
||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Block scale must be represented as FP8-E4M3")
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
# shuffles ourselves.
|
||||
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
del layer.weight_scale_swizzled
|
||||
weight = layer.weight.data
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
epilogue_tile_m = 128
|
||||
weight = shuffle_matrix_a(weight.view(torch.uint8),
|
||||
epilogue_tile_m)
|
||||
weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
|
||||
torch.uint8), epilogue_tile_m).reshape(
|
||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
del layer.weight_scale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -833,7 +859,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.use_marlin:
|
||||
if self.backend == "marlin":
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
@@ -859,9 +885,21 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
|
||||
assert (layer.alpha.dtype == torch.float32)
|
||||
|
||||
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
|
||||
layer.weight_scale_swizzled, layer.alpha,
|
||||
output_dtype)
|
||||
mm_args = (
|
||||
x_fp4,
|
||||
layer.weight,
|
||||
x_blockscale,
|
||||
layer.weight_scale_swizzled,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
|
||||
else:
|
||||
out = cutlass_scaled_fp4_mm(*mm_args)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
Reference in New Issue
Block a user