[XPU] add xpu backend implementation of mxfp8 quant (#38682)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -102,6 +102,48 @@ def _xpu_ops_deepseek_scaling_rope_fake(
|
||||
return query, key
|
||||
|
||||
|
||||
def _xpu_mxfp8_quantize_impl(
    x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 on XPU: FP8 values plus one shared E8M0
    scale per 32-element group along the last dimension.

    Args:
        x: Input tensor; its last dimension must be a multiple of 32.
        dtype: Target FP8 dtype, ``torch.float8_e4m3fn`` or
            ``torch.float8_e5m2``. When ``None``, the platform's default
            FP8 dtype is used.

    Returns:
        Tuple ``(x_q, x_s)``: ``x_q`` has the same shape as ``x`` in the
        FP8 dtype; ``x_s`` holds one ``torch.float8_e8m0fnu`` scale per
        group, shape ``x.shape[:-1] + (x.shape[-1] // 32,)``.
    """
    MXFP8_BLOCK_SIZE = 32
    # Fix: give the shape precondition a diagnostic message instead of a
    # bare assert.
    assert x.shape[-1] % MXFP8_BLOCK_SIZE == 0, (
        f"Last dim of x ({x.shape[-1]}) must be divisible by "
        f"{MXFP8_BLOCK_SIZE} for MXFP8 quantization."
    )
    if dtype is not None:
        assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), (
            f"Unsupported dtype for xpu_mxfp8_quantize: {dtype}. "
            # Fix: dropped the pointless f-prefix (no placeholders here).
            "Expected torch.float8_e4m3fn or torch.float8_e5m2."
        )
    else:
        dtype = current_platform.fp8_dtype()

    # Clamp range of the chosen FP8 format; eps guards the kernel against
    # all-zero groups (division by zero when computing scales).
    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max
    eps = 1e-10

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    # One float32 scale slot per 32-element group along the last dim.
    shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
    # NOTE(review): the trailing ``True`` presumably selects power-of-two
    # (E8M0-compatible) scales in the kernel -- confirm against the C++
    # op signature.
    torch.ops._C.per_token_group_fp8_quant(
        x, x_q, x_s, MXFP8_BLOCK_SIZE, eps, fp8_min, fp8_max, True
    )
    # Kernel writes float32 scales; convert them to the E8M0 format the
    # MXFP8 layout expects.
    x_s = x_s.to(torch.float8_e8m0fnu)
    return x_q, x_s
|
||||
|
||||
|
||||
def _xpu_mxfp8_quantize_fake(
|
||||
x: torch.Tensor, dtype: torch.dtype | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if dtype is None:
|
||||
dtype = current_platform.fp8_dtype()
|
||||
|
||||
MXFP8_BLOCK_SIZE = 32
|
||||
|
||||
shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
|
||||
x_s = torch.zeros(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
return x.to(dtype), x_s.to(torch.float8_e8m0fnu)
|
||||
|
||||
|
||||
# Module-level guard so custom-op registration runs at most once per process;
# flipped to True after the direct_register_custom_op calls complete.
_OPS_REGISTERED = False
|
||||
|
||||
@@ -504,6 +546,12 @@ class xpu_ops:
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="xpu_mxfp8_quantize",
|
||||
op_func=_xpu_mxfp8_quantize_impl,
|
||||
fake_impl=_xpu_mxfp8_quantize_fake,
|
||||
)
|
||||
|
||||
_OPS_REGISTERED = True
|
||||
|
||||
|
||||
|
||||
@@ -205,6 +205,12 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def xpu_mxfp8_quantize(
    x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MXFP8 via the registered ``vllm`` custom op.

    Thin dispatch wrapper: forwards to ``torch.ops.vllm.xpu_mxfp8_quantize``
    and returns its ``(quantized, scales)`` result unchanged.
    """
    result = torch.ops.vllm.xpu_mxfp8_quantize(x, dtype)
    return result
|
||||
|
||||
|
||||
class Mxfp8LinearOp:
|
||||
def __init__(self):
|
||||
self.backend = select_mxfp8_linear_backend()
|
||||
|
||||
Reference in New Issue
Block a user