[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Commit c1909e7e8c (parent b95877509b)
Author: bnellnm
Date: 2025-07-02 09:08:27 -04:00
Committed by: GitHub
36 changed files with 2698 additions and 1584 deletions

@@ -5,7 +5,10 @@ from typing import Optional, Union
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    group_broadcast)
from vllm.platforms import current_platform
from vllm.utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
    return ref_out, ref_scale.view((1, ))

-def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
-                             As: torch.Tensor, Bs: torch.Tensor, block_size,
-                             output_dtype):
+def native_w8a8_block_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype,
+    compute_type: torch.dtype = torch.float32,
+) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and
@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
`Bs` (float32).
The output is returned in the specified `output_dtype`.
"""
-    A = A.to(torch.float32)
-    B = B.to(torch.float32)
+    A = A.to(compute_type)
+    B = B.to(compute_type)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
-    assert n_tiles == Bs.shape[0]
-    assert k_tiles == Bs.shape[1]
+    assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
+    assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"

    C_shape = (M, N)
-    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+    C = torch.zeros(C_shape, dtype=compute_type, device=A.device)

    A_tiles = [
        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
@@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
    C = C.reshape(origin_C_shape).to(output_dtype)
    return C
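
The new `compute_type` parameter only changes the dtype used for the intermediate math; it defaults to torch.float32, so existing callers keep the old behavior. A rough sketch of driving the reference matmul end to end (the shapes, block size, and the combination with the per-token-group and per-block helpers defined further down in this file are illustrative assumptions, not part of the diff):

# Illustrative shapes only; block_size is [block_n, block_k].
M, K, N, block_size = 4, 256, 512, [128, 128]
A = torch.randn(M, K, dtype=torch.bfloat16)
B = torch.randn(N, K, dtype=torch.bfloat16)

# Quantize activations per token group and weights per block.
A_q, A_s = native_per_token_group_quant_fp8(A, group_size=block_size[1])
B_q, B_s = per_block_cast_to_fp8(B, block_size)

ref = native_w8a8_block_matmul(A_q, B_q, A_s, B_s, block_size,
                               output_dtype=torch.bfloat16)
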
def native_per_token_group_quant_fp8(x,
                                     group_size,
                                     eps=1e-10,
                                     dtype=torch.float8_e4m3fn):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch."""
    assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
                                           "be divisible by `group_size`")
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / fp8_max
    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s
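
A quick sanity-check sketch for the group quantizer above; the shapes, group size, and tolerances are arbitrary illustrative choices:

# Illustrative: 8 tokens, hidden size 512, groups of 128.
x = torch.randn(8, 512, dtype=torch.float32)
x_q, x_s = native_per_token_group_quant_fp8(x, group_size=128)
assert x_q.shape == x.shape
assert x_s.shape == (8, 512 // 128)   # one scale per 128-element group

# Reconstruct by broadcasting each group's scale over its 128 elements.
x_dq = x_q.to(torch.float32).reshape(8, 4, 128) * x_s.unsqueeze(-1)
torch.testing.assert_close(x_dq.reshape(8, 512), x, atol=0.1, rtol=0.1)
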
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    """
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Use float32 for scale calculation for stability
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_.to(torch.float32) / x_s).round().clamp(
        min=int8_min, max=int8_max).to(dtype)  # Round before clamping
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s

DEFAULT_BLOCK_SHAPE = [128, 128]


def per_block_cast_to_fp8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    block_m, block_n = block_shape
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales
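
This pads the input up to a multiple of the block shape, scales each (block_m, block_n) tile into the fp8 e4m3 range (±448), and returns the unpadded quantized tensor plus one scale per tile. A small shape-check sketch, with illustrative sizes deliberately chosen not to be multiples of 128:

w = torch.randn(300, 200)             # illustrative, non-multiple-of-128 sizes
w_q, w_s = per_block_cast_to_fp8(w)   # uses DEFAULT_BLOCK_SHAPE = [128, 128]
assert w_q.shape == (300, 200)        # padding is stripped from the output
assert w_s.shape == (3, 2)            # ceil(300/128) x ceil(200/128) tiles
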
def per_block_cast_to_int8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    block_m, block_n = block_shape
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales

def dequant(
    t: torch.Tensor,
    scale: Optional[torch.Tensor],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
    if scale is not None:
        f32 = torch.float32
        if per_act_token_quant or block_shape is None:
            return (t.to(f32) * scale).to(out_dtype)
        else:
            return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype)
    else:
        return t.to(out_dtype)
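
dequant handles both scale layouts produced by the helpers above: a per-token (or per-tensor) scale broadcasts directly against the tensor, while block-wise scales are expanded to the tensor's full shape via group_broadcast. A round-trip sketch against the block-wise cast, with illustrative sizes:

w = torch.randn(256, 256)             # illustrative weight-like tensor
w_q, w_s = per_block_cast_to_fp8(w)
w_dq = dequant(w_q, w_s, block_shape=DEFAULT_BLOCK_SHAPE,
               per_act_token_quant=False, out_dtype=torch.float32)
# w_dq should approximate w up to fp8 rounding error.
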
def native_batched_masked_quant_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    num_expert_tokens: torch.Tensor,
    A_scale: Optional[torch.Tensor] = None,
    B_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> torch.Tensor:
    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        if A.dtype.itemsize == 1 and block_shape is not None:
            assert A_scale is not None and B_scale is not None
            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
                                           block_shape, C.dtype)
            C[e, :num_tokens, :] = tmp[:num_tokens, :]
        elif A.dtype.itemsize == 1 and block_shape is None:
            assert A_scale is not None and B_scale is not None
            A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
            B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
            C[e, :num_tokens, :] = (
                A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
        else:
            assert A_scale is None
            assert B_scale is None
            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C
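
This reference treats dim 0 as the expert dimension and only writes the first num_expert_tokens[e] rows of each expert's output, mirroring the masked batched layout used by the batched MoE kernels. A minimal unquantized invocation (the expert count, shapes, and per-expert token counts below are made up for illustration):

# Illustrative: 4 experts, up to 16 tokens per expert, K=128, N=256.
E, max_tokens, K, N = 4, 16, 128, 256
A = torch.randn(E, max_tokens, K)
B = torch.randn(E, N, K)
C = torch.zeros(E, max_tokens, N)
num_expert_tokens = torch.tensor([16, 3, 0, 7], dtype=torch.int32)

native_batched_masked_quant_matmul(A, B, C, num_expert_tokens)
# Rows past each expert's token count are left untouched (zero here).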