Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 (#35053)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Author: danisereb
Date: 2026-02-24 17:45:13 +02:00
Committer: GitHub
Parent: a0c7081695
Commit: 9609b1f18d

3 changed files with 230 additions and 11 deletions

@@ -553,6 +553,83 @@ if has_flashinfer():
            rounded_m, rounded_n, dtype=torch.uint8, device=a.device
        )
    @torch.library.custom_op(
        "vllm::mm_mxfp8",
        mutates_args=[],
        device_types="cuda",
    )
    def mm_mxfp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        from flashinfer import mm_mxfp8 as mm_mxfp8_

        return mm_mxfp8_(
            A,
            B,
            A_scale,
            B_scale,
            out=None,
            out_dtype=out_dtype,
            backend=backend,
        )

    # Fake (meta) implementation so the custom op can be traced by
    # torch.compile without invoking the real FlashInfer kernel.
    @torch.library.register_fake(
        "vllm::mm_mxfp8",
    )
    def mm_mxfp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        # A is [m, k], B is [k, n] -> output [m, n]
        return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device)
    def flashinfer_mm_mxfp8(
        a: torch.Tensor,
        b: torch.Tensor,
        block_scale_a: torch.Tensor,
        block_scale_b: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        """MXFP8 MM helper - mirrors the flashinfer_scaled_fp4_mm API.

        Takes non-transposed weights and handles the transpose internally.

        CRITICAL: the mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for
        accuracy and optimal performance. Both input and weight scales should
        be in the swizzled format produced by FlashInfer's
        mxfp8_quantize(is_sf_swizzled_layout=True).
        """
        # a shape [M, K]
        # b shape [N, K] (non-transposed weight)
        assert a.ndim == 2 and b.ndim == 2
        assert a.shape[1] == b.shape[1]  # K dimension must match

        if block_scale_b.ndim != 1:
            raise ValueError(
                "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
                f"got shape={tuple(block_scale_b.shape)}"
            )

        # Output tensor [M, N]
        return mm_mxfp8(
            a,
            b.t(),  # Transpose weight: [N, K] -> [K, N]
            block_scale_a,
            block_scale_b,
            out_dtype,
            backend=backend,
        )
    def flashinfer_scaled_fp4_mm(
        a: torch.Tensor,