[Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE4M3fn",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
|
||||
@@ -343,6 +343,8 @@ __global__ void Marlin(
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
|
||||
static_assert(group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
@@ -357,9 +359,10 @@ __global__ void Marlin(
|
||||
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
|
||||
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
is_a_8bit || b_type == vllm::kFE4M3fn ||
|
||||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
|
||||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
@@ -373,7 +376,7 @@ __global__ void Marlin(
|
||||
const int group_size =
|
||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||
const int scales_expert_stride =
|
||||
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
@@ -692,9 +695,8 @@ __global__ void Marlin(
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
// Scale sizes/strides without act_order
|
||||
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
constexpr int s_sh_stride =
|
||||
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
} else if (group_blocks >= b_sh_wr_iters) {
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
} else {
|
||||
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
|
||||
@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
} else if (b_type_id == vllm::kFE4M3fn.id() &&
|
||||
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
|
||||
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE4M3fn",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
|
||||
@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
} else if (b_type_id == vllm::kFE4M3fn.id() &&
|
||||
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
|
||||
@@ -327,6 +327,9 @@ __global__ void Marlin(
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (s_type == vllm::kFE8M0fnu) {
|
||||
// MXFP8: FP8 weights with e8m0 microscaling block scales
|
||||
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
@@ -334,6 +337,7 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
constexpr bool is_a_8bit = a_type.size_bits() == 8;
|
||||
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
|
||||
if constexpr (!is_a_8bit) {
|
||||
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
|
||||
}
|
||||
@@ -343,7 +347,7 @@ __global__ void Marlin(
|
||||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
is_a_8bit || b_type == vllm::kFE4M3fn ||
|
||||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
|
||||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
@@ -555,9 +559,8 @@ __global__ void Marlin(
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
// Scale sizes/strides without act_order
|
||||
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
constexpr int s_sh_stride =
|
||||
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
@@ -997,7 +1000,7 @@ __global__ void Marlin(
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
} else if (group_blocks >= b_sh_wr_iters) {
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
} else {
|
||||
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
|
||||
@@ -151,6 +151,12 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": [scalar_types.bfloat16],
|
||||
"b_type": scalar_types.float8_e4m3fn,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
|
||||
@@ -23,7 +23,7 @@ from tests.quantization.utils import is_quant_method_supported
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
|
||||
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
|
||||
MOE_MODEL = "allenai/OLMoE-1B-7B-0125-Instruct"
|
||||
# A small dense model (no MoE) to validate the linear-only path.
|
||||
DENSE_MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
kMxfp8Static,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -582,6 +583,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
kMxfp8Static,
|
||||
kNvfp4Static,
|
||||
]
|
||||
return weight_key in SUPPORTED_W
|
||||
@@ -609,7 +611,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
|
||||
@property
|
||||
def quant_type_id(self) -> int:
|
||||
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
|
||||
if self.quant_config.use_int4_w4a16:
|
||||
return scalar_types.uint4b8.id
|
||||
elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:
|
||||
|
||||
@@ -436,13 +436,27 @@ def convert_to_fp8_moe_kernel_format(
|
||||
elif fp8_backend == Fp8MoeBackend.AITER:
|
||||
w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
|
||||
elif fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
|
||||
layer,
|
||||
w13,
|
||||
w2,
|
||||
w13_scale,
|
||||
w2_scale,
|
||||
)
|
||||
weight_block_size = getattr(layer, "weight_block_size", None)
|
||||
if weight_block_size == [1, 32]:
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_mxfp8_moe_layer_for_marlin,
|
||||
)
|
||||
|
||||
w13, w2, w13_scale, w2_scale = prepare_mxfp8_moe_layer_for_marlin(
|
||||
layer,
|
||||
w13,
|
||||
w2,
|
||||
w13_scale,
|
||||
w2_scale,
|
||||
)
|
||||
else:
|
||||
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
|
||||
layer,
|
||||
w13,
|
||||
w2,
|
||||
w13_scale,
|
||||
w2_scale,
|
||||
)
|
||||
elif fp8_backend in [
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
|
||||
@@ -15,14 +15,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
|
||||
{
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
}
|
||||
_SUPPORTED_BACKENDS = (
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.MARLIN,
|
||||
)
|
||||
|
||||
_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
|
||||
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
"marlin": Fp8MoeBackend.MARLIN,
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,11 @@ def select_mxfp8_moe_backend(
|
||||
|
||||
# Auto-select: pick the first supported backend.
|
||||
for backend in _SUPPORTED_BACKENDS:
|
||||
try:
|
||||
experts_cls = _select_kernel_cls(backend, config)
|
||||
except ValueError:
|
||||
continue
|
||||
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
|
||||
return backend, _select_kernel_cls(backend, config)
|
||||
return backend, experts_cls
|
||||
|
||||
raise ValueError("No MXFP8 MoE backends available.")
|
||||
|
||||
@@ -67,10 +67,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_BLOCK_SIZE,
|
||||
MXFP8_SCALE_DTYPE,
|
||||
MXFP8_VALUE_DTYPE,
|
||||
Mxfp8LinearBackend,
|
||||
Mxfp8LinearOp,
|
||||
mxfp8_e4m3_quantize,
|
||||
swizzle_mxfp8_scale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
apply_nvfp4_linear,
|
||||
@@ -1499,8 +1497,8 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
|
||||
return 100
|
||||
# Marlin kernel supports MXFP8 on SM80+
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
@@ -1555,9 +1553,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
"Dynamic quantization is not supported."
|
||||
)
|
||||
|
||||
self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
|
||||
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)
|
||||
self.mxfp8_linear_op = Mxfp8LinearOp()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1615,36 +1611,6 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
|
||||
"""Not swizzled - MXFP8 GEMM emulation"""
|
||||
weight = layer.weight.data # [N, K]
|
||||
N, K = weight.shape
|
||||
scale_k = K // MXFP8_BLOCK_SIZE
|
||||
|
||||
# Slice weight_scale to match weight dimensions (handles padding)
|
||||
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
|
||||
|
||||
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
|
||||
"""Swizzled - MXFP8 GEMM Flashinfer CUTLASS"""
|
||||
weight = layer.weight.data # [N, K]
|
||||
N, K = weight.shape
|
||||
|
||||
# 2D weight scale
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
# Swizzle the weight scales
|
||||
scale_k = K // MXFP8_BLOCK_SIZE
|
||||
weight_scale_2d = weight_scale[:N, :scale_k].contiguous()
|
||||
weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
|
||||
|
||||
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(
|
||||
weight_scale_swizzled.contiguous(), requires_grad=False
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Validate weight tensor
|
||||
if layer.weight.ndim != 2:
|
||||
@@ -1669,14 +1635,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
f" got {layer.weight_scale.dtype}"
|
||||
)
|
||||
|
||||
if self.backend == Mxfp8LinearBackend.EMULATION:
|
||||
# Swizzled layout is not used
|
||||
self._process_weights_after_loading_scale_2d(layer)
|
||||
return
|
||||
|
||||
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
# Swizzled layout is required for Flashinfer CUTLASS
|
||||
self._process_weights_after_loading_scale_1d(layer)
|
||||
self.mxfp8_linear_op.process_weights(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -1684,22 +1643,15 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if layer.weight.dtype != MXFP8_VALUE_DTYPE:
|
||||
raise ValueError(
|
||||
f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
|
||||
)
|
||||
if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
|
||||
raise ValueError(
|
||||
f"Weight scale dtype {layer.weight_scale.dtype} != "
|
||||
f"expected {MXFP8_SCALE_DTYPE}"
|
||||
)
|
||||
|
||||
return self.mxfp8_linear_op.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
workspace=getattr(layer, "workspace", None),
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -34,10 +34,8 @@ from vllm.model_executor.layers.quantization.fp8 import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_BLOCK_SIZE,
|
||||
Mxfp8LinearBackend,
|
||||
Mxfp8LinearOp,
|
||||
mxfp8_e4m3_quantize,
|
||||
swizzle_mxfp8_scale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
@@ -71,7 +69,8 @@ class Mxfp8Config(Fp8Config):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
# Marlin kernel supports MXFP8 on SM80+
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
|
||||
@@ -128,24 +127,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
|
||||
def __init__(self, quant_config: "Mxfp8Config"):
|
||||
self.quant_config = quant_config
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.mxfp8_linear = Mxfp8LinearOp(self._select_backend())
|
||||
logger.info_once(
|
||||
"Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _select_backend() -> Mxfp8LinearBackend:
|
||||
try:
|
||||
from vllm.utils import flashinfer as fi
|
||||
|
||||
_ = fi.mm_mxfp8
|
||||
return Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"FlashInfer mm_mxfp8 not available, "
|
||||
"falling back to MXFP8 emulation backend."
|
||||
)
|
||||
return Mxfp8LinearBackend.EMULATION
|
||||
self.mxfp8_linear = Mxfp8LinearOp()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -180,14 +162,12 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
|
||||
|
||||
weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous())
|
||||
|
||||
if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
|
||||
N, K = layer.weight.shape[0], layer.weight.shape[1]
|
||||
weight_scale = swizzle_mxfp8_scale(weight_scale, N, K)
|
||||
|
||||
layer.input_scale = None
|
||||
replace_parameter(layer, "weight", weight_fp8.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
self.mxfp8_linear.process_weights(layer)
|
||||
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
def apply(
|
||||
@@ -202,6 +182,9 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
bias=bias,
|
||||
workspace=getattr(layer, "workspace", None),
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
)
|
||||
|
||||
|
||||
@@ -255,17 +238,24 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
|
||||
self, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales)."""
|
||||
num_batches = weight.size(0)
|
||||
w_quant = []
|
||||
w_scales = []
|
||||
for i in range(num_batches):
|
||||
mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize(
|
||||
E = weight.size(0)
|
||||
first_q, first_s = mxfp8_e4m3_quantize(weight[0], is_sf_swizzled_layout=False)
|
||||
# Pre-allocate the output tensors rather than stacking.
|
||||
# This is important for consistent memory layout.
|
||||
w_quant = torch.empty(
|
||||
(E, *first_q.shape), dtype=first_q.dtype, device=weight.device
|
||||
)
|
||||
w_scales = torch.empty(
|
||||
(E, *first_s.shape), dtype=first_s.dtype, device=weight.device
|
||||
)
|
||||
w_quant[0] = first_q
|
||||
w_scales[0] = first_s
|
||||
for i in range(1, E):
|
||||
w_quant[i], w_scales[i] = mxfp8_e4m3_quantize(
|
||||
weight[i], is_sf_swizzled_layout=False
|
||||
)
|
||||
w_quant.append(mx_fp8_quant)
|
||||
w_scales.append(mx_fp8_scale)
|
||||
|
||||
return torch.stack(w_quant), torch.stack(w_scales)
|
||||
return w_quant, w_scales
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
|
||||
@@ -336,6 +336,202 @@ def pack_fp8_to_int32(
|
||||
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
|
||||
|
||||
|
||||
def mxfp8_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor:
|
||||
"""Reorder scales for e8m0 kernel layout and convert to float8_e8m0fnu."""
|
||||
# fit the layout of fp8 dequantization
|
||||
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
|
||||
marlin_scales.size(0), -1
|
||||
)
|
||||
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def apply_mxfp8_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
bias: torch.Tensor | None = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (size_n,)
|
||||
|
||||
use_atomic_add = should_use_atomic_add_reduce(
|
||||
m=reshaped_x.size(0),
|
||||
n=size_n,
|
||||
k=size_k,
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
output = ops.marlin_gemm(
|
||||
a=reshaped_x,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
b_bias=bias,
|
||||
b_scales=weight_scale,
|
||||
a_scales=None,
|
||||
global_scale=None,
|
||||
b_zeros=None,
|
||||
g_idx=None,
|
||||
perm=None,
|
||||
workspace=workspace,
|
||||
b_q_type=scalar_types.float8_e4m3fn,
|
||||
size_m=reshaped_x.size(0),
|
||||
size_n=size_n,
|
||||
size_k=size_k,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
)
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
def prepare_mxfp8_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"""Repack MXFP8 weights and scales into Marlin kernel format.
|
||||
|
||||
Expects the layer to have:
|
||||
- weight: [N, K] float8_e4m3fn
|
||||
- weight_scale: [N, K//32] uint8 (e8m0 encoded)
|
||||
- input_size_per_partition / output_size_per_partition
|
||||
"""
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
group_size = 32 # MX standard block size
|
||||
|
||||
device = layer.weight.device
|
||||
|
||||
# WORKSPACE
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# WEIGHT - repack FP8 weights to Marlin format
|
||||
perm = torch.empty(0, dtype=torch.int, device=device)
|
||||
qweight = pack_fp8_to_int32(layer.weight, size_k_first=False)
|
||||
qweight = qweight.T.contiguous()
|
||||
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=qweight,
|
||||
perm=perm,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
replace_parameter(layer, "weight", marlin_qweight)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Convert uint8 scales -> e8m0fnu -> param_dtype for permutation
|
||||
# Scales are [N, K//32], need [K//32, N] for marlin_permute_scales
|
||||
param_dtype = torch.get_default_dtype()
|
||||
scales = layer.weight_scale.data[:part_size_n, : part_size_k // group_size]
|
||||
scales = scales.contiguous()
|
||||
scales = scales.view(torch.float8_e8m0fnu).to(param_dtype)
|
||||
scales = scales.T.contiguous()
|
||||
|
||||
# Permute scales to Marlin layout
|
||||
marlin_scales = marlin_permute_scales(
|
||||
s=scales,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=group_size,
|
||||
)
|
||||
|
||||
# Reorder for e8m0 kernel layout and convert back to e8m0fnu
|
||||
marlin_scales = mxfp8_marlin_process_scales(marlin_scales)
|
||||
replace_parameter(layer, "weight_scale", marlin_scales)
|
||||
|
||||
# BIAS
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
assert layer.bias.shape == (part_size_n,)
|
||||
bias = marlin_permute_bias(layer.bias)
|
||||
replace_parameter(layer, "bias", bias)
|
||||
|
||||
|
||||
def prepare_mxfp8_moe_layer_for_marlin(
|
||||
layer: torch.nn.Module,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Repack MXFP8 MoE weights and scales into Marlin kernel format.
|
||||
|
||||
Args:
|
||||
layer: MoE layer (used to read params_dtype and attach workspace).
|
||||
w13: [E, 2*N, K] float8_e4m3fn weights.
|
||||
w2: [E, K, N] float8_e4m3fn weights.
|
||||
w13_scale: [E, 2*N, K//32] uint8 e8m0 scales.
|
||||
w2_scale: [E, K, N//32] uint8 e8m0 scales.
|
||||
|
||||
Returns:
|
||||
(w13, w2, w13_scale, w2_scale) in Marlin format.
|
||||
"""
|
||||
group_size = 32
|
||||
e = w13.shape[0]
|
||||
w13_n = w13.shape[1]
|
||||
k = w13.shape[2]
|
||||
n = w2.shape[2]
|
||||
|
||||
device = w13.device
|
||||
param_dtype = torch.get_default_dtype()
|
||||
perm = torch.empty(0, dtype=torch.int, device=device)
|
||||
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
|
||||
if "w13" in name:
|
||||
size_n, size_k = w13_n, k
|
||||
else:
|
||||
size_n, size_k = k, n
|
||||
|
||||
assert weight.shape == (e, size_n, size_k)
|
||||
|
||||
tensor_list = []
|
||||
for i in range(e):
|
||||
qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
|
||||
qweight = qweight.T.contiguous()
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=qweight,
|
||||
perm=perm,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
tensor_list.append(marlin_qweight)
|
||||
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
|
||||
w13 = repack_weight(w13, "w13")
|
||||
w2 = repack_weight(w2, "w2")
|
||||
|
||||
def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
|
||||
if "w13" in name:
|
||||
size_n, size_k = w13_n, k
|
||||
else:
|
||||
size_n, size_k = k, n
|
||||
|
||||
tensor_list = []
|
||||
for i in range(e):
|
||||
s = scales[i][:size_n, : size_k // group_size].contiguous()
|
||||
s = s.view(torch.float8_e8m0fnu).to(param_dtype)
|
||||
s = s.T.contiguous()
|
||||
marlin_s = marlin_permute_scales(
|
||||
s=s,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size,
|
||||
)
|
||||
marlin_s = mxfp8_marlin_process_scales(marlin_s)
|
||||
tensor_list.append(marlin_s)
|
||||
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
|
||||
w13_scale = permute_scales(w13_scale, "w13")
|
||||
w2_scale = permute_scales(w2_scale, "w2")
|
||||
|
||||
return w13, w2, w13_scale, w2_scale
|
||||
|
||||
|
||||
def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
|
||||
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
|
||||
if is_a_8bit:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import flashinfer as vllm_flashinfer
|
||||
@@ -15,6 +16,7 @@ logger = init_logger(__name__)
|
||||
class Mxfp8LinearBackend(Enum):
|
||||
EMULATION = "emulation"
|
||||
FLASHINFER_CUTLASS = "flashinfer-cutlass"
|
||||
MARLIN = "marlin"
|
||||
|
||||
|
||||
# MXFP8 constants
|
||||
@@ -23,6 +25,28 @@ MXFP8_SCALE_DTYPE = torch.uint8
|
||||
MXFP8_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def select_mxfp8_linear_backend() -> Mxfp8LinearBackend:
|
||||
"""Select the best MXFP8 linear backend for the current device.
|
||||
|
||||
- SM100+ (Blackwell): FLASHINFER_CUTLASS (native MXFP8 W8A8 GEMM)
|
||||
- SM80+ (Ampere/Ada): MARLIN (MXFP8 W8A16 GEMM)
|
||||
- Otherwise: EMULATION (dequant to BF16 fallback)
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.has_device_capability(100):
|
||||
return Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
is_fp8_marlin_supported,
|
||||
)
|
||||
|
||||
if is_fp8_marlin_supported():
|
||||
return Mxfp8LinearBackend.MARLIN
|
||||
|
||||
return Mxfp8LinearBackend.EMULATION
|
||||
|
||||
|
||||
def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
|
||||
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
|
||||
scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8
|
||||
@@ -47,17 +71,71 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
|
||||
return sf_swizzled.contiguous().view(-1)
|
||||
|
||||
|
||||
def _mxfp8_e4m3_quantize_torch(
|
||||
x: torch.Tensor,
|
||||
is_sf_swizzled_layout: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Naive MXFP8 quantization.
|
||||
For each block of 32 elements along the last dimension, compute a
|
||||
shared e8m0 scale (the biased exponent of the block-wise amax)
|
||||
and quantize each element to float8_e4m3fn.
|
||||
|
||||
Returns (quantized_values [same shape, fp8], scales uint8).
|
||||
Scale shape depends on is_sf_swizzled_layout:
|
||||
False -> [..., K//32] (row-major 2D)
|
||||
True -> [flat swizzled 1D]
|
||||
"""
|
||||
assert x.shape[-1] % MXFP8_BLOCK_SIZE == 0
|
||||
orig_shape = x.shape
|
||||
num_blocks = x.shape[-1] // MXFP8_BLOCK_SIZE
|
||||
|
||||
x_fp32 = x.to(torch.float32)
|
||||
x_blocked = x_fp32.view(*orig_shape[:-1], num_blocks, MXFP8_BLOCK_SIZE)
|
||||
|
||||
amax = x_blocked.abs().amax(dim=-1)
|
||||
amax = amax.clamp(min=torch.finfo(torch.float32).tiny)
|
||||
scale_biased = torch.floor(torch.log2(amax)) + 127.0
|
||||
scale_biased = scale_biased.clamp(0, 254)
|
||||
scales_uint8 = scale_biased.to(torch.uint8)
|
||||
|
||||
descale = torch.exp2(scale_biased - 127.0)
|
||||
x_scaled = x_blocked / descale.unsqueeze(-1)
|
||||
|
||||
x_fp8 = x_scaled.view(orig_shape).to(MXFP8_VALUE_DTYPE)
|
||||
|
||||
if x.ndim == 2:
|
||||
M, K = x.shape
|
||||
scales_uint8 = scales_uint8.view(M, -1)
|
||||
if is_sf_swizzled_layout:
|
||||
scales_uint8 = swizzle_mxfp8_scale(scales_uint8, M=M, K=K)
|
||||
elif x.ndim == 3:
|
||||
B, M, K = x.shape
|
||||
scales_uint8 = scales_uint8.view(B, M, -1)
|
||||
if is_sf_swizzled_layout:
|
||||
swizzled = []
|
||||
for i in range(B):
|
||||
swizzled.append(swizzle_mxfp8_scale(scales_uint8[i], M=M, K=K))
|
||||
scales_uint8 = torch.cat(swizzled)
|
||||
|
||||
return x_fp8, scales_uint8
|
||||
|
||||
|
||||
def _mxfp8_e4m3_quantize_impl(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 (FP8-E4M3 values + uint8 E8M0 block scales).

    Uses the flashinfer CUDA kernel on devices with compute capability
    >= 10.0 (SM100/Blackwell) and falls back to the pure-torch reference
    implementation everywhere else.

    Args:
        x: Input tensor; the last dimension must be divisible by the
            MXFP8 block size.
        is_sf_swizzled_layout: When True, scales are returned in the
            swizzled layout expected by CUTLASS; otherwise as a
            row-major 2D tensor of shape ``[M, K // BLOCK]``.

    Returns:
        Tuple of (quantized fp8 tensor, uint8 scale tensor).
    """
    from vllm.platforms import current_platform

    if current_platform.has_device_capability(100):
        from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize

        x_q, x_scales = flashinfer_mxfp8_quantize(
            x, is_sf_swizzled_layout=is_sf_swizzled_layout
        )
        # flashinfer may hand back a flat 1D scale tensor; reshape to the
        # documented [M, K // BLOCK] layout for the non-swizzled 2D case.
        if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
            x_scales = x_scales.view(x.size(0), -1)
        return x_q, x_scales

    return _mxfp8_e4m3_quantize_torch(x, is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def mxfp8_e4m3_quantize(
|
||||
@@ -128,11 +206,51 @@ direct_register_custom_op(
|
||||
|
||||
|
||||
class Mxfp8LinearOp:
    """Applies MXFP8 (FP8-E4M3 values + E8M0 block scales) linear layers.

    Dispatches weight post-processing and the GEMM itself to one of the
    supported backends (Marlin, FlashInfer CUTLASS, or torch emulation).
    """

    def __init__(self, backend: "Mxfp8LinearBackend | None" = None):
        """Create the op.

        Args:
            backend: Explicit backend to use. When None (the default), the
                best backend for the current platform is selected
                automatically via ``select_mxfp8_linear_backend()``.

        Raises:
            ValueError: If an explicitly supplied backend is not a member
                of ``Mxfp8LinearBackend``.
        """
        if backend is None:
            # Auto-select based on platform / kernel availability.
            backend = select_mxfp8_linear_backend()
        elif backend not in Mxfp8LinearBackend:
            raise ValueError(f"Unsupported backend: {backend}")
        self.backend = backend
        logger.info_once("Using %s backend for MXFP8 GEMM", self.backend)
|
||||
def process_weights(self, layer: torch.nn.Module) -> None:
    """Convert freshly-loaded MXFP8 weights into the backend's native format."""
    if self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
        self._process_weights_flashinfer_cutlass(layer)
    elif self.backend == Mxfp8LinearBackend.MARLIN:
        self._process_weights_marlin(layer)
    else:
        # Any remaining backend runs the dequant-to-BF16 emulation path.
        self._process_weights_emulation(layer)
|
||||
|
||||
def _process_weights_emulation(self, layer: torch.nn.Module) -> None:
    """Keep weight plus plain 2D uint8 scales for dequant-to-BF16 emulation."""
    w = layer.weight.data  # [N, K]
    out_features, in_features = w.shape
    n_scale_cols = in_features // MXFP8_BLOCK_SIZE

    # Slice scales down to [N, K // BLOCK] — presumably dropping loader
    # padding; TODO(review): confirm the loaded scale tensor can be larger.
    scales_2d = layer.weight_scale.data[:out_features, :n_scale_cols]

    layer.weight = Parameter(w.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(scales_2d.contiguous(), requires_grad=False)
|
||||
|
||||
def _process_weights_flashinfer_cutlass(self, layer: torch.nn.Module) -> None:
    """Swizzle scales to the F8_128x4 layout used by flashinfer CUTLASS."""
    w = layer.weight.data  # [N, K]
    out_features, in_features = w.shape
    n_scale_cols = in_features // MXFP8_BLOCK_SIZE

    # Take the row-major [N, K // BLOCK] view of the scales, then swizzle
    # it into the layout the CUTLASS kernel expects.
    scales_2d = layer.weight_scale.data[:out_features, :n_scale_cols].contiguous()
    swizzled = swizzle_mxfp8_scale(scales_2d, M=out_features, K=in_features)

    layer.weight = Parameter(w.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(swizzled.contiguous(), requires_grad=False)
|
||||
|
||||
def _process_weights_marlin(self, layer: torch.nn.Module) -> None:
    """Repack MXFP8 weights and scales into Marlin kernel format."""
    # Imported lazily so the Marlin utilities are only pulled in when
    # this backend is actually selected.
    from vllm.model_executor.layers.quantization.utils import marlin_utils_fp8

    marlin_utils_fp8.prepare_mxfp8_layer_for_marlin(layer)
|
||||
|
||||
def _apply_emulation(
|
||||
self,
|
||||
@@ -142,7 +260,6 @@ class Mxfp8LinearOp:
|
||||
out_dtype: torch.dtype,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Validate weight_scale dtype and shape (must be 2D for TORCH backend)
|
||||
if weight_scale.dtype != MXFP8_SCALE_DTYPE:
|
||||
raise ValueError(
|
||||
f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, "
|
||||
@@ -219,6 +336,32 @@ class Mxfp8LinearOp:
|
||||
output_shape = (*input_shape[:-1], N)
|
||||
return output.view(output_shape)
|
||||
|
||||
def _apply_marlin(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
    *,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Run the Marlin fused MXFP8 GEMM.

    Note: ``out_dtype`` is accepted for signature parity with the other
    backend paths but is not forwarded to the Marlin helper.
    """
    from vllm.model_executor.layers.quantization.utils import marlin_utils_fp8

    return marlin_utils_fp8.apply_mxfp8_marlin_linear(
        input=input,
        weight=weight,
        weight_scale=weight_scale,
        workspace=workspace,
        size_n=size_n,
        size_k=size_k,
        bias=bias,
    )
|
||||
|
||||
def apply(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
    *,
    workspace: torch.Tensor | None = None,
    size_n: int = 0,
    size_k: int = 0,
) -> torch.Tensor:
    """Dispatch the MXFP8 linear op to the configured backend.

    Args:
        input: Activation tensor.
        weight: MXFP8-packed weight in the backend-specific layout
            produced by :meth:`process_weights`.
        weight_scale: Per-block scales in the backend-specific layout.
        out_dtype: Desired output dtype.
        bias: Optional bias added to the GEMM output.
        workspace: Marlin scratch buffer; required when the backend is
            MARLIN, unused otherwise.
        size_n: GEMM N dimension for the MARLIN backend.
        size_k: GEMM K dimension for the MARLIN backend.

    Returns:
        The GEMM result as a tensor.
    """
    if self.backend == Mxfp8LinearBackend.EMULATION:
        return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)

    if self.backend == Mxfp8LinearBackend.MARLIN:
        # Marlin requires its preallocated workspace and explicit dims.
        assert workspace is not None
        return self._apply_marlin(
            input,
            weight,
            weight_scale,
            out_dtype,
            bias,
            workspace=workspace,
            size_n=size_n,
            size_k=size_k,
        )

    assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
    return self._apply_flashinfer_cutlass(input, weight, weight_scale, out_dtype, bias)
|
||||
|
||||
Reference in New Issue
Block a user