Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 (#35053)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
@@ -553,6 +553,83 @@ if has_flashinfer():
        rounded_m, rounded_n, dtype=torch.uint8, device=a.device
    )


@torch.library.custom_op(
    "vllm::mm_mxfp8",
    mutates_args=[],
    device_types="cuda",
)
def mm_mxfp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    from flashinfer import mm_mxfp8 as mm_mxfp8_

    return mm_mxfp8_(
        A,
        B,
        A_scale,
        B_scale,
        out=None,
        out_dtype=out_dtype,
        backend=backend,
    )


@torch.library.register_fake(
    "vllm::mm_mxfp8",
)
def mm_mxfp8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    # A is [m, k], B is [k, n] -> output [m, n]
    return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device)
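

# Illustrative sketch, not part of this commit: the fake registered above is
# what lets FakeTensorMode / torch.compile infer output shape and dtype
# without launching the CUDA kernel. The helper name and all sizes below are
# made up for the demo; note the fake does not validate the scale shapes.
def _demo_mm_mxfp8_fake() -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        A = torch.empty(128, 256, dtype=torch.uint8, device="cuda")  # [M, K]
        B = torch.empty(256, 512, dtype=torch.uint8, device="cuda")  # [K, N]
        sa = torch.empty(1024, dtype=torch.uint8, device="cuda")  # 1D swizzled scales
        sb = torch.empty(4096, dtype=torch.uint8, device="cuda")
        out = torch.ops.vllm.mm_mxfp8(A, B, sa, sb, torch.bfloat16)
        assert out.shape == (128, 512) and out.dtype == torch.bfloat16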


def flashinfer_mm_mxfp8(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """MXFP8 MM helper that mirrors the flashinfer_scaled_fp4_mm API.

    Takes non-transposed weights and handles the transpose internally.

    CRITICAL: the mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for
    optimal performance and accuracy. Both input and weight scales should be
    in the swizzled format produced by FlashInfer's
    mxfp8_quantize(is_sf_swizzled_layout=True).
    """
    # a shape [M, K]
    # b shape [N, K] (non-transposed weight)
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[1]  # K dimension must match

    if block_scale_b.ndim != 1:
        raise ValueError(
            "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
            f"got shape={tuple(block_scale_b.shape)}"
        )

    # Output tensor [M, N]
    return mm_mxfp8(
        a,
        b.t(),  # Transpose weight: [N, K] -> [K, N]
        block_scale_a,
        block_scale_b,
        out_dtype,
        backend=backend,
    )
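

# Illustrative sketch, not part of this commit: one way a caller might drive
# flashinfer_mm_mxfp8 end to end. The mxfp8_quantize call follows the
# docstring above; treat its exact return values (quantized tensor, 1D
# swizzled scales) as an assumption to verify against the FlashInfer docs.
# The helper name _example_mxfp8_linear is made up.
def _example_mxfp8_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """x: [M, K] bf16 activations, w: [N, K] bf16 weights, both on CUDA."""
    from flashinfer import mxfp8_quantize

    # Quantize both operands to MXFP8 with swizzled 1D scales, as the
    # CUTLASS kernel requires.
    x_q, x_scale = mxfp8_quantize(x, is_sf_swizzled_layout=True)
    w_q, w_scale = mxfp8_quantize(w, is_sf_swizzled_layout=True)
    return flashinfer_mm_mxfp8(x_q, w_q, x_scale, w_scale, torch.bfloat16)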


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,