[Hardware/NVIDIA/Kernel] Enable nvidia/DeepSeek-R1-FP4 Model (#16362)
This commit is contained in:
@@ -745,10 +745,11 @@ def get_cutlass_moe_mm_data(
|
||||
- output_permutation: Permutation that must be used to shuffle the output
|
||||
after executing the MMs.
|
||||
"""
|
||||
torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
|
||||
problem_sizes1, problem_sizes2,
|
||||
input_permutation, output_permutation,
|
||||
num_experts, n, k)
|
||||
return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
|
||||
problem_sizes1, problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts, n, k)
|
||||
|
||||
|
||||
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
|
||||
@@ -767,9 +768,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
|
||||
MMs used in the fused MoE operation.
|
||||
- a/b/c_strides: The data strides passed to grouped matrix multiplication.
|
||||
"""
|
||||
torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales,
|
||||
b_scales, expert_offsets, problem_sizes,
|
||||
a_strides, b_strides, c_strides)
|
||||
return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
|
||||
a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides,
|
||||
c_strides)
|
||||
|
||||
|
||||
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
|
||||
a_scales: torch.Tensor, b_scales: torch.Tensor,
|
||||
alphas: torch.Tensor, problem_sizes: torch.Tensor,
|
||||
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
|
||||
out_dtype: torch.dtype, device: torch.device):
|
||||
"""
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
the gemms for each combination based on the specified problem sizes.
|
||||
|
||||
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
||||
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
||||
input and expert weights.
|
||||
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||
each expert begins its computation. The number of tokens
|
||||
computed with expert E is expert_offsets[E + 1] -
|
||||
expert_offsets[E] And the sf_size per expert is
|
||||
sf_offset[E+1] - sf_offset[E]
|
||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||
MMs used in the fused MoE operation.
|
||||
"""
|
||||
m_topk = a_tensors.shape[0]
|
||||
n = b_tensors.shape[1]
|
||||
c_shape = (m_topk, n)
|
||||
c = torch.empty(c_shape, device=device, dtype=out_dtype)
|
||||
torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
|
||||
b_scales, alphas, problem_sizes,
|
||||
expert_offsets, sf_offsets)
|
||||
return c.to(out_dtype)
|
||||
|
||||
|
||||
# aqlm
|
||||
@@ -960,6 +993,57 @@ def scaled_fp4_quant(
|
||||
return output, output_scale
|
||||
|
||||
|
||||
def scaled_fp4_experts_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
blockscale_offsets: torch.Tensor,
|
||||
topk: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
MAX_TOKENS_PER_EXPERT: int = 163840,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
packed MoE Inputs.
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP4
|
||||
expert_map: The expert map tensor
|
||||
input_global_scale: A scalar scaling factor for the entire tensor.
|
||||
expert_offsets: The expert offsets tensor
|
||||
blockscale_offsets: The blockscale offsets tensor
|
||||
Outputs:
|
||||
output: The quantized tensor in FP4
|
||||
output_scales: The blockscale tensor in FP8-E4M3
|
||||
"""
|
||||
assert not current_platform.is_rocm()
|
||||
assert input_tensor.ndim == 2, (
|
||||
f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')
|
||||
|
||||
input_tensor = input_tensor[
|
||||
expert_map] if expert_map is not None else input_tensor
|
||||
m_numtopk, k = input_tensor.shape
|
||||
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
|
||||
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for"
|
||||
f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")
|
||||
scales_k = k // 16
|
||||
padded_k = (scales_k + (4 - 1)) // 4
|
||||
|
||||
# output is uint8 and packed fp4 values
|
||||
output = torch.empty(m_numtopk,
|
||||
k // 2,
|
||||
device=input_tensor.device,
|
||||
dtype=torch.uint8)
|
||||
output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device)
|
||||
torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor,
|
||||
input_global_scale, expert_offsets,
|
||||
blockscale_offsets)
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn)
|
||||
return output, output_scales
|
||||
|
||||
|
||||
# fp8
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
|
||||
@@ -36,7 +36,7 @@ if HAS_TRITON:
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8)
|
||||
cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts, fused_moe, fused_topk, get_config_file_name,
|
||||
grouped_topk)
|
||||
@@ -48,4 +48,5 @@ if HAS_TRITON:
|
||||
"get_config_file_name",
|
||||
"grouped_topk",
|
||||
"cutlass_moe_fp8",
|
||||
"cutlass_moe_fp4",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Fused MoE kernel."""
|
||||
""" CUTLASS based Fused MoE kernels."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
||||
@@ -178,3 +179,126 @@ def cutlass_moe_fp8(
|
||||
if not apply_router_weight_on_input:
|
||||
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
return c2.sum(dim=1)
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
MAX_TOKENS_PER_EXPERT = 65536
|
||||
|
||||
|
||||
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
|
||||
device: torch.device):
|
||||
"""
|
||||
MoE implementation for FP4 Inputs
|
||||
|
||||
# Gemm 1
|
||||
a: Input tensor: [m, k] (half/bfloat16)
|
||||
a1_gscale: Activation scale per expert: [e] (float32)
|
||||
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
|
||||
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
|
||||
(Note: `n` is the up projection output dim, `k` is the input dim in
|
||||
full precision)
|
||||
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
|
||||
(Block size = 16 for NVFP4)
|
||||
|
||||
# Gemm 2
|
||||
a2_gscale: Activation scale per expert: [e]
|
||||
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
|
||||
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
|
||||
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
|
||||
|
||||
topk_weights: [m, topk] dtype: float8
|
||||
topk_ids: [m, topk] dtype: float8
|
||||
|
||||
m, n, k: Unquantized weight shapes, dtype: int
|
||||
e: number of experts, dtype: int
|
||||
|
||||
assumes that topk < k < n to satisfy - up/down projection expectations.
|
||||
"""
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
|
||||
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
|
||||
assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
|
||||
and w2_blockscale.ndim
|
||||
== 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
|
||||
m_a, k_a = a.shape
|
||||
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
|
||||
e_w2, k_w2, half_n_w2 = w2_fp4.shape
|
||||
|
||||
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
|
||||
" between weights.")
|
||||
assert (k_a // 2 == half_k_w1
|
||||
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
|
||||
assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
|
||||
"expected `n`")
|
||||
assert (m == m_a), "input shape mismatch"
|
||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
|
||||
== m), ("topk must be provided for each row of a")
|
||||
assert (m <= MAX_TOKENS_PER_EXPERT), (
|
||||
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m = {m}")
|
||||
out_dtype = a.dtype
|
||||
num_topk = topk_ids.shape[1]
|
||||
|
||||
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
||||
# Problem size: (num_experts, (m,2n,k))
|
||||
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
|
||||
# Problem size: (num_experts, (m,n,k))
|
||||
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
# problem shapes should have [m, n, k]
|
||||
# Note that problem sizes are based on logical number of elements.
|
||||
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, a_map, c_map, e, n, k)
|
||||
|
||||
tokens_per_expert = problem_sizes1[:, 0]
|
||||
rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
|
||||
blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
|
||||
blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
|
||||
|
||||
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
|
||||
a,
|
||||
a1_gscale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
expert_map=a_map,
|
||||
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
|
||||
|
||||
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
|
||||
w1_blockscale, w1_alphas, problem_sizes1,
|
||||
expert_offsets[:-1], blockscale_offsets[:-1],
|
||||
out_dtype, device)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
# hidden size dimension is split to one halfpytho sized tensor.
|
||||
intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
|
||||
device=device,
|
||||
dtype=out_dtype)
|
||||
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
intermediate,
|
||||
a2_gscale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
|
||||
|
||||
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
|
||||
w2_alphas, problem_sizes2, expert_offsets[:-1],
|
||||
blockscale_offsets[:-1], out_dtype, device)
|
||||
del int_fp4, int_blockscale
|
||||
out = (c2[c_map].view(m, num_topk, k) *
|
||||
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
|
||||
return out.to(dtype=out_dtype)
|
||||
|
||||
@@ -643,7 +643,7 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
||||
if expert_id == -1:
|
||||
return
|
||||
|
||||
quant_method_name = self.quant_method.__class__.__name__
|
||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||
# against known CompressionFormat enum values that have this quality
|
||||
@@ -697,8 +697,9 @@ class FusedMoE(torch.nn.Module):
|
||||
# this is needed for compressed-tensors only
|
||||
loaded_weight = loaded_weight.to(param.data.device)
|
||||
|
||||
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
if ("compressed" in quant_method_name.lower()
|
||||
and param.data[expert_id] != 1
|
||||
and (param.data[expert_id] - loaded_weight).abs() > 1e-5):
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param.data[expert_id]} "
|
||||
@@ -718,6 +719,22 @@ class FusedMoE(torch.nn.Module):
|
||||
tp_rank=self.tp_rank)
|
||||
return
|
||||
|
||||
if "ModelOpt" in quant_method_name:
|
||||
if ('weight_scale_2' in weight_name
|
||||
or 'input_scale' in weight_name):
|
||||
self._load_per_tensor_weight_scale(shard_id=shard_id,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
elif "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank)
|
||||
return
|
||||
|
||||
# Case weight scales, zero_points and offset
|
||||
if ("scale" in weight_name or "zero" in weight_name
|
||||
or "offset" in weight_name):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
|
||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -210,25 +212,37 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
"`hf_quant_config.json` file for your model's "
|
||||
"quant configuration.")
|
||||
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
||||
if ("group_size" and "kv_cache_quant_algo"
|
||||
and "exclude_modules") not in quant_config:
|
||||
raise ValueError("NVFP4 quantization requires group size and "
|
||||
"kv_cache_quant_algo specified in "
|
||||
"hf_quant_config.json")
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
||||
exclude_modules, group_size)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: List):
|
||||
import re
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.exclude_modules):
|
||||
if (is_layer_skipped(prefix, self.exclude_modules)
|
||||
or self.is_layer_excluded(prefix, self.exclude_modules)):
|
||||
return UnquantizedLinearMethod()
|
||||
return ModelOptNvFp4LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptNvFp4FusedMoE(self)
|
||||
return None
|
||||
|
||||
|
||||
@@ -409,3 +423,235 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"""
|
||||
MoE Method for FP4 Quantization.
|
||||
Args:
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
if not self.quant_config.is_checkpoint_nvfp4_serialized:
|
||||
raise ValueError("NVFP4 quantization was selected, "
|
||||
" dynamic quantization is not supported.")
|
||||
|
||||
layer.quant_config = self.quant_config
|
||||
weight_dtype = torch.uint8
|
||||
weight_scale_dtype = torch.float8_e4m3fn
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
# GEMM 1
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
# GEMM 2
|
||||
w2_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=weight_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
|
||||
w13_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
intermediate_size_per_partition //
|
||||
self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype),
|
||||
input_dim=1,
|
||||
output_dim=2,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
|
||||
|
||||
w13_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||
|
||||
w2_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
|
||||
w13_input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
num_experts, 2, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
num_experts, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# GEMM 1
|
||||
|
||||
assert torch.allclose(
|
||||
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
|
||||
"Expected w1_weight_scale_2 to equal w3_weight_scale_2")
|
||||
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
|
||||
requires_grad=False)
|
||||
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
|
||||
torch.float32)
|
||||
layer.g1_alphas = Parameter(
|
||||
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = self.swizzle_blockscale(
|
||||
layer.w13_weight_scale)
|
||||
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w13_input_scale_quant = Parameter(
|
||||
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w2_input_scale_quant = Parameter(
|
||||
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Router weight on input is not "
|
||||
"supported for ModelOptNvFp4FusedMoE.")
|
||||
assert expert_map is None, ("Expert Parallelism /expert_map "
|
||||
"is currently not supported for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w1_alphas=layer.g1_alphas,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alphas=layer.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
device=x.device).to(x.dtype)
|
||||
|
||||
Reference in New Issue
Block a user