[Attention] Deepseek v3 MLA support with FP8 compute (#12601)
This PR implements the Deepseek V3 support by performing matrix absorption the fp8 weights --------- Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: simon-mo <simon.mo@hey.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -20,6 +20,120 @@ FUSED_LAYER_NAME_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
|
||||
int]):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
|
||||
|
||||
# Useful when treating N-dimensional group scaling as extended numpy-style
|
||||
# broadcasting in numpy simply stretches dimensions with an extent of 1 to match
|
||||
# the target shape by repeating the data along that dimension (broadcasting)
|
||||
# , we extend these semantics to say if the extent of a dimension in the
|
||||
# source shape is not 1 and does not match the target shape we repeat each
|
||||
# element along that dimension src_shape[dim] // target_shape[dim] times
|
||||
# example if we have:
|
||||
# a = [[1, 2], and target_shape = (2, 4)
|
||||
# [3, 4]]
|
||||
# then we would expand a to:
|
||||
# a = [[1, 1, 2, 2],
|
||||
# [3, 3, 4, 4]]
|
||||
# NOTE this function this function does not explicitly broadcast dimensions
|
||||
# with an extent of 1, since this can be done implicitly by pytorch
|
||||
def group_broadcast(t, shape):
|
||||
for i, s in enumerate(shape):
|
||||
if t.shape[i] != s and t.shape[i] != 1:
|
||||
assert s % t.shape[i] == 0
|
||||
t = t.unsqueeze(i + 1)\
|
||||
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
|
||||
.flatten(i, i + 1)
|
||||
return t
|
||||
|
||||
|
||||
# Quantize assuming once scale per group of elements with shape group_shape,
|
||||
# example group shapes:
|
||||
# * (-1, -1) for per-tensor quantization
|
||||
# * (1, -1) for per-row quantization
|
||||
# * (-1, 1) for per-column quantization
|
||||
# * (128, 128) for 128x128 deepseek style block quantization
|
||||
# * (1, 128) for deepseek style activation quantization
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: Tuple[int, int],
|
||||
quant_dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
assert quant_dtype.is_floating_point, \
|
||||
"currently `scaled_quantize` only supports floating point dtypes " \
|
||||
"but could be extended to support other dtypes"
|
||||
|
||||
finfo = torch.finfo(quant_dtype)
|
||||
|
||||
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
|
||||
assert x.ndim == 2
|
||||
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
|
||||
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
|
||||
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
|
||||
|
||||
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
|
||||
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
|
||||
x_blkd_permd = x_blkd_permd.flatten(start_dim=2)
|
||||
|
||||
# Compute scales
|
||||
min_val, max_val = x_blkd_permd.aminmax(dim=-1)
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax
|
||||
|
||||
# Apply scale and convert form:
|
||||
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
|
||||
x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
|
||||
.clamp(min=finfo.min, max=finfo.max)\
|
||||
.reshape(blk_m, blk_n, group_shape[0], group_shape[1])\
|
||||
.permute(0, 2, 1, 3)\
|
||||
.reshape(x.shape)
|
||||
|
||||
return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()
|
||||
|
||||
|
||||
# inverses `scaled_quantize`
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[Tuple[int, int]] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
group_shape = _normalize_quant_group_shape(x_q, group_shape)
|
||||
|
||||
if x_s.ndim == 0: # scalar
|
||||
x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor
|
||||
if x_s.ndim == 1:
|
||||
if group_shape is None:
|
||||
raise AssertionError(
|
||||
"if x_s is 1D tensor, group_shape must be provided otherwise "
|
||||
"its ambiguous which dimension to broadcast x_s to")
|
||||
# unsqueeze the scales for the dimension where we want to broadcast
|
||||
# across the full extent
|
||||
if group_shape[0] == x_q.shape[-2]:
|
||||
x_s = x_s.unsqueeze(-2)
|
||||
elif group_shape[1] == x_q.shape[-1]:
|
||||
x_s = x_s.unsqueeze(-1)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"if x_s is a vector we should be broadcasting it to the full "
|
||||
"extent of one of the dimensions")
|
||||
|
||||
if group_shape is not None:
|
||||
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
|
||||
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
|
||||
x_s = group_broadcast(x_s.to(torch.float32), x_q.shape)
|
||||
return (x_q.to(torch.float32) * x_s).to(out_dtype)
|
||||
|
||||
|
||||
def pack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
|
||||
Reference in New Issue
Block a user