[XPU] add xpu backend implementation of mxfp8 quant (#38682)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
zofia
2026-04-08 08:30:35 +08:00
committed by GitHub
parent 70406eb1dc
commit ad3304425b
2 changed files with 54 additions and 0 deletions

View File

@@ -102,6 +102,48 @@ def _xpu_ops_deepseek_scaling_rope_fake(
return query, key
def _xpu_mxfp8_quantize_impl(
    x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 on XPU via the per-token-group FP8 kernel.

    The last dimension of ``x`` is split into groups of 32 elements and each
    group gets one shared scale, which is returned as ``float8_e8m0fnu``.

    Args:
        x: Input tensor; its last dimension must be a multiple of 32.
        dtype: Target FP8 dtype (``float8_e4m3fn`` or ``float8_e5m2``).
            When ``None``, the platform's default FP8 dtype is used.

    Returns:
        ``(quantized, scales)`` where ``quantized`` has the same shape as
        ``x`` and ``scales`` has the last dimension shrunk by a factor of 32.
    """
    block = 32  # MXFP8 group size along the last dimension
    assert x.shape[-1] % block == 0
    if dtype is None:
        dtype = current_platform.fp8_dtype()
    else:
        assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), (
            f"Unsupported dtype for xpu_mxfp8_quantize: {dtype}. "
            f"Expected torch.float8_e4m3fn or torch.float8_e5m2."
        )
    info = torch.finfo(dtype)
    lo, hi = info.min, info.max
    quantized = torch.empty_like(x, device=x.device, dtype=dtype)
    # One float32 scale per group of `block` elements; converted to the
    # power-of-two E8M0 representation after the kernel fills it in.
    scale_shape = (*x.shape[:-1], x.shape[-1] // block)
    scales = torch.empty(scale_shape, device=x.device, dtype=torch.float32)
    torch.ops._C.per_token_group_fp8_quant(
        x, quantized, scales, block, 1e-10, lo, hi, True
    )
    return quantized, scales.to(torch.float8_e8m0fnu)
def _xpu_mxfp8_quantize_fake(
x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if dtype is None:
dtype = current_platform.fp8_dtype()
MXFP8_BLOCK_SIZE = 32
shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
x_s = torch.zeros(shape, device=x.device, dtype=torch.float32)
return x.to(dtype), x_s.to(torch.float8_e8m0fnu)
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@@ -504,6 +546,12 @@ class xpu_ops:
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="xpu_mxfp8_quantize",
op_func=_xpu_mxfp8_quantize_impl,
fake_impl=_xpu_mxfp8_quantize_fake,
)
_OPS_REGISTERED = True

View File

@@ -205,6 +205,12 @@ direct_register_custom_op(
)
def xpu_mxfp8_quantize(
    x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 by dispatching to the registered custom op.

    Thin public wrapper around ``torch.ops.vllm.xpu_mxfp8_quantize``;
    returns the ``(quantized, scales)`` pair produced by the op.
    """
    quantized, scales = torch.ops.vllm.xpu_mxfp8_quantize(x, dtype)
    return quantized, scales
class Mxfp8LinearOp:
def __init__(self):
self.backend = select_mxfp8_linear_backend()