Nuke vLLM because this keeps confusing people
This commit is contained in:
69
Dockerfile
69
Dockerfile
@@ -26,72 +26,3 @@ RUN cd /root/nvfp4-megamoe-kernel && pip install -e .
|
||||
COPY cutedsl/ /root/nvfp4-megamoe-kernel/cutedsl/

# Prepend the kernel checkout to PYTHONPATH. The ${VAR:+...} form only emits the
# ":" separator when PYTHONPATH is already set, so an unset variable does not
# leave a trailing empty entry (which Python resolves to the current directory).
ENV PYTHONPATH="/root/nvfp4-megamoe-kernel${PYTHONPATH:+:${PYTHONPATH}}"

# Patch vLLM — overwrite model files and register architecture.
# Every target directory derives from VLLM_ROOT, so a base image that installs
# vLLM elsewhere (e.g. a different Python minor version) needs one --build-arg.
# Each individual *_DIR argument can still be overridden on its own.
ARG VLLM_ROOT=/usr/local/lib/python3.12/dist-packages/vllm
ARG VLLM_MODELS_DIR=${VLLM_ROOT}/model_executor/models
ARG VLLM_LAYERS_DIR=${VLLM_ROOT}/model_executor/layers
ARG VLLM_QUANT_DIR=${VLLM_ROOT}/model_executor/layers/quantization
ARG VLLM_FUSED_MOE_DIR=${VLLM_ROOT}/model_executor/layers/fused_moe
ARG VLLM_LOADER_DIR=${VLLM_ROOT}/model_executor/model_loader

# Core model patches
COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py

# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell).
# The nightly image has all MHC in layers/mhc.py (imports tilelang at top level).
# Our replacement is pure PyTorch — no tilelang dependency at all.
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py

# CSA/HCA attention kernel (replaces FlashMLA on Blackwell)
COPY vllm/patches/layers/csa_attention.py ${VLLM_LAYERS_DIR}/csa_attention.py

# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
ARG VLLM_NVFP4_DIR=${VLLM_ROOT}/model_executor/kernels/linear/nvfp4
COPY vllm/kernels/linear/nvfp4/cutedsl.py ${VLLM_NVFP4_DIR}/cutedsl.py

# Patch KV cache utils to handle DeepseekV4 SWA page sizes > MLA page sizes
# (SWA layers have larger page sizes than compressed MLA layers on Blackwell)
ARG VLLM_CORE_DIR=${VLLM_ROOT}/v1/core
COPY vllm/patches/patch_kv_cache_utils.py /tmp/patch_kv_cache_utils.py
RUN python3 /tmp/patch_kv_cache_utils.py "${VLLM_CORE_DIR}/kv_cache_utils.py" \
    && rm /tmp/patch_kv_cache_utils.py

# Patch SWA cache and Indexer cache for Blackwell (no FlashMLA alignment)
ARG VLLM_SPARSE_SWA_DIR=${VLLM_ROOT}/v1/attention/backends/mla
# VLLM_LAYERS_DIR2 is kept for existing --build-arg users; it defaults to the
# same directory as VLLM_LAYERS_DIR.
ARG VLLM_LAYERS_DIR2=${VLLM_LAYERS_DIR}
COPY vllm/patches/patch_swa_cache.py /tmp/patch_swa_cache.py
RUN python3 /tmp/patch_swa_cache.py "${VLLM_SPARSE_SWA_DIR}/sparse_swa.py" \
    && rm /tmp/patch_swa_cache.py
COPY vllm/patches/patch_indexer_cache.py /tmp/patch_indexer_cache.py
RUN python3 /tmp/patch_indexer_cache.py "${VLLM_LAYERS_DIR2}/deepseek_v4_attention.py" \
    && rm /tmp/patch_indexer_cache.py
COPY vllm/patches/patch_compressor_cache.py /tmp/patch_compressor_cache.py
RUN python3 /tmp/patch_compressor_cache.py "${VLLM_LAYERS_DIR2}/deepseek_compressor.py" \
    && rm /tmp/patch_compressor_cache.py

# Debug: print layer name mismatch
ARG VLLM_WORKER_DIR=${VLLM_ROOT}/v1/worker
COPY vllm/patches/patch_debug_layers.py /tmp/patch_debug_layers.py
RUN python3 /tmp/patch_debug_layers.py "${VLLM_WORKER_DIR}/gpu_model_runner.py" \
    && rm /tmp/patch_debug_layers.py

# Register CuTeDSL kernel in vLLM's linear kernel selection
ARG VLLM_LINEAR_DIR=${VLLM_ROOT}/model_executor/kernels/linear
COPY vllm/patches/register_cutedsl_kernel.py /tmp/register_cutedsl_kernel.py
RUN python3 /tmp/register_cutedsl_kernel.py "${VLLM_LINEAR_DIR}/__init__.py" \
    && rm /tmp/register_cutedsl_kernel.py

# Config patches (add cutedsl to MoEBackend)
ARG VLLM_CONFIG_DIR=${VLLM_ROOT}/config
COPY vllm/patches/kernel.py ${VLLM_CONFIG_DIR}/kernel.py

# NVFP4 MoE backend registration
COPY vllm/patches/fused_moe/oracle/nvfp4.py ${VLLM_FUSED_MOE_DIR}/oracle/nvfp4.py
COPY vllm/patches/fused_moe/experts/cutedsl_moe.py ${VLLM_FUSED_MOE_DIR}/experts/cutedsl_moe.py

# Register DeepseekV4ForCausalLM model architecture (if not already in upstream).
# sed exits 0 even when its anchor line is absent, so the trailing grep turns a
# silent no-op into a build failure instead of an image that cannot load the model.
RUN if ! grep -q '"DeepseekV4ForCausalLM"' "${VLLM_MODELS_DIR}/registry.py"; then \
        sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
            "${VLLM_MODELS_DIR}/registry.py"; \
    fi \
    && grep -q '"DeepseekV4ForCausalLM"' "${VLLM_MODELS_DIR}/registry.py" \
    || { echo "DeepseekV4ForCausalLM was not registered in ${VLLM_MODELS_DIR}/registry.py" >&2; exit 1; }

# Verify
RUN python3 -c "import torch; print(f'PyTorch {torch.__version__} OK')" && \
    python3 -c "import vllm; print('vLLM OK')" && \
    python3 -c "import nvfp4_megamoe_kernel; print('NVFP4 kernel OK')" && \
    python3 -c "import cutlass; print('CuTeDSL OK')"
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""CuTeDSL NVFP4 Quantization Method for vLLM
|
||||
|
||||
Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM.
|
||||
After process_weights_after_loading, the module's quant_method is swapped
|
||||
to CuTeDSLNvfp4LinearMethod which routes forward() through CuTeDSL
|
||||
via torch.library.custom_op (opaque to torch.compile).
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
|
||||
|
||||
class CuTeDSLNvfp4Method(LinearMethodBase):
    """Load-time quant method that hands an NVFP4 linear layer to CuTeDSL.

    ``process_weights_after_loading`` performs the whole conversion:

    1. Reads the layer's NVFP4 tensors (``weight``, ``weight_scale``,
       ``weight_scale_2``).
    2. Re-lays them out for CuTeDSL (packed FP4 and block scales transposed).
    3. Builds a ``CuTeDSLNvfp4Linear`` runner and registers it.
    4. Stores the runner id and output width on the layer.
    5. Removes the original scale attributes from the layer.
    6. Installs ``CuTeDSLNvfp4LinearMethod`` as the layer's ``quant_method``.
    """

    # Attributes removed from the layer once the runner holds the weights.
    _STALE_ATTRS = (
        "weight_scale", "weight_scale_2", "input_scale",
        "input_global_scale", "input_global_scale_inv",
        "weight_global_scale", "alpha", "weight_scale_inv",
    )

    def __init__(self, is_fused: bool = False):
        """
        Args:
            is_fused: True when the layer fuses two sub-projections that each
                carry their own ``weight_scale_2`` entry. In that case both
                are normalized to the larger global scale and the ratio is
                folded into the block scales.
        """
        self.is_fused = is_fused

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        # Intentionally a no-op: per the original author's note the weights
        # were already created by ModelOptNvFp4LinearMethod.
        return None

    def _resolve_global_scale(self, layer, block_scales, out_features):
        """Return ``(global_scale, block_scales)`` for the layer.

        For a fused layer with two distinct global scales, the block scales
        of each half are rescaled (through float32) so that a single global
        scale, the larger of the two, applies to the whole weight.
        """
        weight_scale_2 = layer.weight_scale_2.data
        has_dual_scales = self.is_fused and weight_scale_2.numel() == 2
        if not has_dual_scales:
            return weight_scale_2.max().item(), block_scales

        gs1 = weight_scale_2[0].item()
        gs2 = weight_scale_2[1].item()
        gs = max(gs1, gs2)
        if gs1 == gs2:
            return gs, block_scales

        # block_scales is (K_sf, N) here: columns are output features, first
        # sub-projection followed by the second.
        widths = getattr(layer, 'logical_widths', None)
        if widths is not None and len(widths) == 2:
            boundary = widths[0]
        else:
            boundary = out_features // 2

        rescaled = block_scales.float()
        rescaled[:, :boundary] *= (gs1 / gs)
        rescaled[:, boundary:] *= (gs2 / gs)
        return gs, rescaled.to(torch.float8_e4m3fn)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the layer's NVFP4 weights and switch it to the CuTeDSL path."""
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        packed_weight = layer.weight.data  # (out, in // 2), two FP4 values per byte
        device = packed_weight.device
        out_features = packed_weight.shape[0]
        in_features = 2 * packed_weight.shape[1]

        # Reinterpret the bytes as packed FP4 and transpose to (K_packed, N).
        fp4_kn = packed_weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: cast to FP8 if needed, then (N, K_sf) -> (K_sf, N).
        block_scales = layer.weight_scale.data
        if block_scales.dtype != torch.float8_e4m3fn:
            block_scales = block_scales.to(torch.float8_e4m3fn)
        block_scales = block_scales.permute(1, 0).contiguous()

        gs, block_scales = self._resolve_global_scale(
            layer, block_scales, out_features
        )

        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=device,
        )
        runner.fp4 = [fp4_kn]
        runner.sf = [block_scales]
        runner.gs = [gs]
        runner.finalize_weights()

        # The forward path looks the runner up by id through the custom op.
        layer._cutedsl_runner_id = register_runner(runner)
        layer._cutedsl_out_features = out_features

        # Warmup with a single random token to derive the activation global
        # scale. The original author notes that the checkpoint's input_scale
        # does not match what the runtime quantizer expects, and that one
        # token keeps memory overhead low during weight loading.
        with torch.no_grad():
            warmup_input = torch.randn(1, in_features,
                                       dtype=torch.bfloat16, device=device) * 2.0
            runner.compute_activation_global_scale(warmup_input)
            del warmup_input
            torch.cuda.empty_cache()

        # Placeholder BF16 weight on the same device; the original comments
        # say some vLLM code paths still expect ``layer.weight`` to exist there.
        placeholder = torch.zeros(out_features, in_features,
                                  dtype=torch.bfloat16, device=device)
        layer.weight = torch.nn.Parameter(placeholder, requires_grad=False)

        # Best-effort removal of the original NVFP4 parameters.
        for stale in self._STALE_ATTRS:
            if not hasattr(layer, stale):
                continue
            try:
                delattr(layer, stale)
            except Exception:
                pass

        # From here on the layer only needs the forward-only method.
        layer.quant_method = CuTeDSLNvfp4LinearMethod()

    def apply(self, layer, x, bias=None):
        """Always raises: forward calls belong to ``CuTeDSLNvfp4LinearMethod``."""
        raise NotImplementedError(
            "CuTeDSLNvfp4Method should be replaced by "
            "CuTeDSLNvfp4LinearMethod after process_weights_after_loading"
        )
|
||||
|
||||
|
||||
class CuTeDSLNvfp4LinearMethod(LinearMethodBase):
    """Forward-only quant method: input -> CuTeDSL NVFP4 GEMM -> output.

    Expects the layer to carry ``_cutedsl_runner_id`` and
    ``_cutedsl_out_features``, which ``CuTeDSLNvfp4Method`` sets.
    """

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        # Nothing to create; this method is installed on an already-loaded layer.
        return None

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        """Run the registered runner on ``x`` and add ``bias`` when given."""
        out = nvfp4_linear_gemm(
            x,
            layer._cutedsl_runner_id,
            layer._cutedsl_out_features,
        )
        return out if bias is None else out + bias
|
||||
@@ -1,133 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CuTeDSL NVFP4 Linear Kernel for vLLM.
|
||||
|
||||
Registers as an NvFp4LinearKernel so that vLLM kernel selection
|
||||
(init_nvfp4_linear_kernel) picks it up on Blackwell GPUs.
|
||||
Routes NVFP4 GEMM through CuTeDSL's MLIR-compiled grouped GEMM.
|
||||
|
||||
Uses torch.library.custom_op to make Dynamo (torch.compile) treat the
|
||||
GEMM as opaque. The runner's _run_impl is already cudagraph-safe.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via the CuTeDSL framework (Blackwell SM100+)."""

    # Layer attributes dropped once the CuTeDSL runner owns the weights.
    _STALE_ATTRS = (
        "weight_scale", "weight_global_scale",
        "input_global_scale", "input_global_scale_inv",
        "alpha", "weights_padding_cols", "weight_scale_2",
        "input_scale",
    )

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether the device is SM100 or newer.

        Args:
            compute_capability: Capability to test. When falsy, the current
                platform's capability is queried instead. Both an object with
                a ``major`` attribute and a plain integer are accepted.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        cap = compute_capability or current_platform.get_device_capability()
        if cap is not None:
            major = getattr(cap, "major", None)
            if major is None:
                # The signature advertises an int, which has no ``.major``.
                # NOTE(review): assumes the integer form is major * 10 + minor
                # (e.g. 100 for SM100) — confirm against the caller.
                major = int(cap) // 10
            if major >= 10:
                return True, None
        return False, "CuTeDSL NVFP4 requires SM100+ (Blackwell)"

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer configuration."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert NVFP4 weights into CuTeDSL kernel format.

        Builds a ``CuTeDSLNvfp4Linear`` runner from ``layer.weight``,
        ``layer.weight_scale`` and ``layer.weight_global_scale``, registers
        it, stores the runner id and output width on the layer, replaces
        ``layer.weight`` with a zero BF16 placeholder and removes the
        original quantization attributes.
        """
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

        w_uint8 = layer.weight.data  # (out, in // 2), two FP4 values per byte
        device = w_uint8.device
        out_features = w_uint8.shape[0]
        in_features = w_uint8.shape[1] * 2

        # Reinterpret as packed FP4 and transpose to (K_packed, N).
        w_fp4 = w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

        # Block scales: FP8, transposed to (K_sf, N).
        sf = layer.weight_scale.data
        if sf.dtype != torch.float8_e4m3fn:
            sf = sf.to(torch.float8_e4m3fn)
        sf = sf.permute(1, 0).contiguous()

        # Inspect the element count BEFORE calling .item(): .item() raises on
        # a two-element tensor, which previously made the dual-scale branch
        # unreachable.
        weight_gs = layer.weight_global_scale.data
        if weight_gs.numel() == 2:
            gs0, gs1 = weight_gs.flatten().tolist()
            gs = max(gs0, gs1)
            if gs0 != gs1:
                # Normalize both halves to the larger global scale by folding
                # the ratio into the block scales (float32 round-trip).
                sf_f32 = sf.float()
                logical_widths = getattr(layer, 'logical_widths', None)
                if logical_widths is not None and len(logical_widths) == 2:
                    split_point = logical_widths[0]
                else:
                    split_point = out_features // 2
                sf_f32[:, :split_point] *= (gs0 / gs)
                sf_f32[:, split_point:] *= (gs1 / gs)
                sf = sf_f32.to(torch.float8_e4m3fn)
        else:
            # Single global scale; .item() still raises for any other count.
            gs = weight_gs.item()

        runner = CuTeDSLNvfp4Linear(
            in_features=in_features,
            out_features=out_features,
            device=str(device),
        )
        runner.fp4 = [w_fp4]
        runner.sf = [sf]
        runner.gs = [gs]
        runner.finalize_weights()

        # Warmup: derive the activation global scale from one random token.
        # The original author notes that the checkpoint's input_scale does
        # not match what the runtime quantizer expects, and that a single
        # token keeps GPU memory overhead low during weight loading.
        with torch.no_grad():
            sample = torch.randn(
                1, in_features,
                dtype=torch.bfloat16, device=str(device),
            ) * 2.0
            runner.compute_activation_global_scale(sample)
            del sample
            torch.cuda.empty_cache()

        # Register the runner and store the ID (not the runner itself).
        layer._cutedsl_runner_id = register_runner(runner)
        layer._cutedsl_out_features = out_features

        # Replace weight with a placeholder on the same device; the original
        # comments say some vLLM code paths expect ``weight`` to live there.
        layer.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.bfloat16,
                        device=device),
            requires_grad=False,
        )

        # Best-effort cleanup: a failure to drop an attribute must not abort
        # loading, but it is logged instead of being discarded silently.
        for attr in self._STALE_ATTRS:
            if not hasattr(layer, attr):
                continue
            try:
                delattr(layer, attr)
            except Exception as exc:
                logger.debug(
                    "Could not remove %s from %s: %s",
                    attr, type(layer).__name__, exc,
                )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the registered CuTeDSL runner on ``x``; add ``bias`` if given."""
        result = nvfp4_linear_gemm(
            x,
            layer._cutedsl_runner_id,
            layer._cutedsl_out_features,
        )
        if bias is not None:
            result = result + bias
        return result
|
||||
@@ -1,536 +0,0 @@
|
||||
"""
|
||||
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
|
||||
|
||||
CUDA-graph-compatible design:
|
||||
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
|
||||
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
|
||||
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
|
||||
- Extra slots (beyond real tokens) are zero and contribute nothing to output
|
||||
- Fixed-shape tensors throughout the forward pass
|
||||
|
||||
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
|
||||
During capture, num_tokens equals the budget — all shapes are fixed.
|
||||
During replay, inputs are padded to the budget size. Our runner always
|
||||
processes max_slots = budget * top_k rows; padding rows are zeros.
|
||||
|
||||
Dynamo compatibility: uses torch.library.custom_op via cutedsl.custom_ops
|
||||
so torch.compile (fullgraph mode) treats the GEMM as an opaque black box.
|
||||
The runner's _run_impl is already cudagraph-safe.
|
||||
"""
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
quantize_activation_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
make_b_k_major,
|
||||
assemble_scales_3d_side,
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
|
||||
|
||||
class CuTeDSLMoERunner:
|
||||
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts, hidden_size, intermediate_size,
             max_num_tokens=8192, top_k=8, device="cuda",
             experts_start_idx=0):
    """Record the MoE geometry and declare every lazily-filled attribute.

    Nothing is allocated here; weights and buffers are created later by
    ``_ensure_stacked`` / ``_allocate_buffers``.
    """
    # Geometry and placement.
    self.num_experts = num_experts
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.max_num_tokens = max_num_tokens
    self.top_k = top_k
    self.device = device
    self.experts_start_idx = experts_start_idx
    self._swiglu_limit = None  # changed through set_swiglu_limit()

    # Per-expert weight inputs, consumed by _ensure_stacked.
    self.l1_fp4 = self.l1_sf = self.l1_gs = None
    self.l2_fp4 = self.l2_sf = self.l2_gs = None

    # Stacked weight tensors produced by _ensure_stacked.
    self._l1_mat_b = self._l2_mat_b = None
    self._l1_scale_b = self._l2_scale_b = None
    self._l1_gsb = self._l2_gsb = None

    # Placeholder activation global scale: 1 / (6 * 448) = 1/2688.
    # compute_activation_global_scales replaces it with measured values.
    placeholder_scale = 1.0 / (6.0 * 448.0)
    self._l1_activation_global_scale = placeholder_scale
    self._l2_activation_global_scale = placeholder_scale

    # Buffers that are pre-allocated later for cudagraph use.
    self._token_indices = None
    self._expert_id_range = None
    self._expert_offsets_buf = None
    self._per_expert_scale_bufs_l1 = None
    self._per_expert_scale_bufs_l2 = None
    self._padded_x_sf_buf_l1 = None
    self._padded_x_sf_buf_l2 = None
    self._l1_gsa_buf = None
    self._l2_gsa_buf = None
    self._output_buf = None
    self._row_indices_buf = None
    self._padded_hidden_buf = None
    self._padded_activated_buf = None  # unused, using shared
    self._padded_expert_offsets_buf = None

    # Number of 128-row chunks reserved per expert.
    total_slots = self.max_num_tokens * self.top_k
    self._max_chunks_per_expert = cutedsl_ceil_div(
        total_slots, self.num_experts * 128
    )
    self._buffers_allocated = False
|
||||
|
||||
def set_swiglu_limit(self, limit: float | None):
|
||||
"""Set the swiglu_limit for activation clamping."""
|
||||
self._swiglu_limit = limit
|
||||
|
||||
def _fill_token_indices(self):
|
||||
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times).
|
||||
|
||||
Builds on CPU first, then copies to GPU, to ensure correctness
|
||||
regardless of CuTeDSL JIT GPU memory corruption.
|
||||
"""
|
||||
src = torch.arange(self.max_num_tokens, dtype=torch.int32)
|
||||
cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
||||
self._token_indices.copy_(cpu_indices)
|
||||
|
||||
def _allocate_buffers(self):
    """Allocate every fixed-size buffer the forward pass reuses.

    Per-runner buffers are stored on ``self``. The large padded buffers
    live in the class-level ``_shared_padded_bufs`` dict, keyed by
    ``str(self.device)``, and are created only by the first runner on a
    device; later runners reuse them as-is.
    """
    cls = CuTeDSLMoERunner
    device = self.device

    def fp8_zeros(rows, cols):
        # Zero buffer created as float16 and then cast to FP8.
        return torch.zeros(
            rows, cols, dtype=torch.float16, device=device
        ).to(torch.float8_e4m3fn)

    def packed_fp4_zeros(rows, packed_cols):
        # Zero byte buffer reinterpreted as packed FP4 pairs.
        return torch.zeros(
            rows, packed_cols, dtype=torch.uint8, device=device
        ).view(torch.float4_e2m1fn_x2)

    # Scale columns (one per 16 inputs), rounded up to a multiple of 4.
    # L1 and L2 differ because their K dimensions differ.
    sf_cols_l1 = cutedsl_ceil_div(cutedsl_ceil_div(self.hidden_size, 16), 4) * 4
    sf_cols_l2 = cutedsl_ceil_div(cutedsl_ceil_div(self.intermediate_size, 16), 4) * 4

    self._per_expert_scale_bufs_l1 = [
        fp8_zeros(128, sf_cols_l1) for _ in range(self.num_experts)
    ]
    self._per_expert_scale_bufs_l2 = [
        fp8_zeros(128, sf_cols_l2) for _ in range(self.num_experts)
    ]

    # Locate (or create) this device's shared-buffer dict.
    registry = getattr(cls, '_shared_padded_bufs', None)
    if registry is None:
        registry = cls._shared_padded_bufs = {}
    shared = registry.setdefault(str(device), {})

    # Shared activation-scale and output buffers.
    rows_per_expert = self._max_chunks_per_expert * 128
    sf_rows = self.num_experts * rows_per_expert
    if 'xsf_l1' not in shared:
        shared.update({
            'xsf_l1': fp8_zeros(sf_rows, sf_cols_l1),
            'xsf_l2': fp8_zeros(sf_rows, sf_cols_l2),
            'output': torch.zeros(
                self.max_num_tokens, self.hidden_size,
                dtype=torch.bfloat16, device=device
            ),
        })
    self._padded_x_sf_buf_l1 = shared['xsf_l1']
    self._padded_x_sf_buf_l2 = shared['xsf_l2']
    self._output_buf = shared['output']

    # Per-expert activation global scales, filled in place later.
    self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=device)
    self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=device)

    # 0..max_slots-1, sliced by the scale-assembly code.
    self._row_indices_buf = torch.arange(
        self.max_num_tokens * self.top_k, device=device
    )

    # Shared padded hidden / activated buffers (BF16 and packed FP4).
    padded_slots = self.num_experts * rows_per_expert
    if 'hidden' not in shared:
        shared.update({
            'hidden': torch.zeros(
                padded_slots, self.hidden_size,
                dtype=torch.bfloat16, device=device
            ),
            'hidden_fp4': packed_fp4_zeros(padded_slots, self.hidden_size // 2),
            'activated': torch.zeros(
                padded_slots, self.intermediate_size,
                dtype=torch.bfloat16, device=device
            ),
            'activated_fp4': packed_fp4_zeros(padded_slots, self.intermediate_size // 2),
        })
    self._shared_bufs = shared

    # Padded expert offsets: [0, rows, 2*rows, ..., num_experts*rows].
    self._padded_expert_offsets_buf = torch.arange(
        self.num_experts + 1, dtype=torch.int32, device=device
    ) * rows_per_expert

    self._buffers_allocated = True
|
||||
|
||||
def _ensure_stacked(self):
    """Stack the per-expert weights once, then allocate runtime buffers.

    Returns immediately when ``_l1_mat_b`` is already set. Otherwise the
    per-expert lists are converted into stacked tensors, the lists are
    released, and index/offset buffers plus ``_allocate_buffers`` follow.
    """
    if self._l1_mat_b is not None:
        return

    device = self.device

    # Weight preparation comes first. The original author notes that this
    # triggers CuTeDSL JIT compilation, and that tensors allocated before
    # or during compilation may be zeroed, hence buffers are created after.
    self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
    self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
    self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
    self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
    self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=device)
    self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=device)

    # The per-expert inputs are no longer needed once stacked.
    self.l1_fp4 = self.l1_sf = self.l1_gs = None
    self.l2_fp4 = self.l2_sf = self.l2_gs = None

    # Slot -> token map, kept on the device for cudagraph use. It is
    # filled now; the flag below marks that it should be refilled after
    # the first run, because (per the original author) the GEMM JIT on
    # that first call may disturb it again.
    num_slots = self.max_num_tokens * self.top_k
    self._token_indices = torch.zeros(num_slots, dtype=torch.int32, device=device)
    self._fill_token_indices()
    self._needs_token_refill = True

    # [0, 1, ..., num_experts - 1], created on the CPU and then moved.
    host_expert_ids = torch.arange(self.num_experts, dtype=torch.int32)
    self._expert_id_range = host_expert_ids.to(device)

    self._expert_offsets_buf = torch.zeros(
        self.num_experts + 1, dtype=torch.int32, device=device
    )
    self._allocate_buffers()
|
||||
|
||||
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
|
||||
self.l1_fp4 = l1_fp4
|
||||
self.l1_sf = l1_sf
|
||||
self.l1_gs = l1_gs
|
||||
self.l2_fp4 = l2_fp4
|
||||
self.l2_sf = l2_sf
|
||||
self.l2_gs = l2_gs
|
||||
self._l1_mat_b = None
|
||||
|
||||
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
    """Quantize per-expert weights to NVFP4 and store the pieces.

    The two sequences are walked in lockstep. For each expert, the
    transposed L1 and then the transposed L2 weight are passed to
    ``quantize_weight_to_nvfp4`` and the three returned values are
    appended to the matching ``*_fp4`` / ``*_sf`` / ``*_gs`` lists.
    Clearing ``_l1_mat_b`` makes ``_ensure_stacked`` restack afterwards.
    """
    self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
    self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []

    l1_sinks = (self.l1_fp4, self.l1_sf, self.l1_gs)
    l2_sinks = (self.l2_fp4, self.l2_sf, self.l2_gs)

    for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16):
        for weight, (fp4_list, sf_list, gs_list) in ((l1_w, l1_sinks), (l2_w, l2_sinks)):
            q_fp4, q_sf, q_gs = quantize_weight_to_nvfp4(weight.T)
            fp4_list.append(q_fp4)
            sf_list.append(q_sf)
            gs_list.append(q_gs)

    self._l1_mat_b = None
|
||||
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
|
||||
padded_expert_offsets,
|
||||
padded_x_sf_buf, per_expert_bufs):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
|
||||
|
||||
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
|
||||
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
|
||||
|
||||
The buffer is 128-row aligned per expert (from padded_expert_offsets),
|
||||
so the full-buffer swizzle produces the correct layout. The GEMM reads
|
||||
scale_a using padded_expert_offsets, matching the scatter layout.
|
||||
"""
|
||||
K_sf = x_sf.shape[1]
|
||||
padded_x_sf = padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
|
||||
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
|
||||
total_rows = x_sf.shape[0]
|
||||
row_indices = self._row_indices_buf[:total_rows]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_x_sf[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
|
||||
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
|
||||
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
|
||||
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
|
||||
rows = padded_x_sf.shape[0]
|
||||
cols = padded_x_sf.shape[1]
|
||||
R = rows // 128
|
||||
C = cols // 4
|
||||
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
|
||||
return swizzled.reshape(rows, cols)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
    """Measure the L1 and L2 activation global scales from a sample batch.

    Routes the sample through the padded L1 grouped GEMM, applies the
    SiLU(gate) * up activation (clamped when a swiglu limit is set), and
    stores the global scales reported by ``quantize_to_nvfp4`` for the L1
    input and the L2 input. The original author intends this to run
    before cudagraph capture.

    ``topk_weights`` is accepted but not used.

    NOTE(review): unlike ``_run_impl``, expert ids are not offset by
    ``experts_start_idx`` here — confirm callers pass local expert ids.
    """
    self._ensure_stacked()
    device = hidden_states_sample.device
    num_experts = self.num_experts
    inter = self.intermediate_size
    num_slots = hidden_states_sample.shape[0] * topk_ids.shape[1]

    with torch.no_grad():
        # Order slots by expert id; the stable sort keeps token order
        # within each expert.
        flat_ids = topk_ids.reshape(-1)
        slot_tokens = self._token_indices[:num_slots]
        order = flat_ids.argsort(stable=True)
        sorted_ids = flat_ids[order]
        slot_hidden = hidden_states_sample[slot_tokens[order]]

        # L1 global scale taken from the quantizer, then the FP4 inputs
        # and block scales produced with that scale.
        l1_gs = quantize_to_nvfp4(slot_hidden)[2]
        slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)

        # Slots per expert and their cumulative offsets.
        membership = sorted_ids.unsqueeze(1) == self._expert_id_range.unsqueeze(0)
        tokens_per_expert = membership.sum(dim=0).int()
        expert_offsets = self._expert_offsets_buf
        expert_offsets.zero_()
        expert_offsets[1:num_experts + 1] = tokens_per_expert.cumsum(0)

        # The same counts rounded up to multiples of 128 rows.
        padded_counts = ((tokens_per_expert + 127) // 128) * 128
        padded_expert_offsets = self._padded_expert_offsets_buf
        padded_expert_offsets.zero_()
        padded_expert_offsets[1:num_experts + 1] = padded_counts.cumsum(0)

        # Destination row of every slot inside the padded layout.
        slot_rows = self._row_indices_buf[:num_slots]
        owner = torch.searchsorted(
            expert_offsets[1:], slot_rows, right=True
        ).clamp(max=num_experts - 1)
        padded_dst = padded_expert_offsets[owner] + (slot_rows - expert_offsets[owner])

        # Scatter the packed FP4 activations through a byte view.
        padded_x_fp4 = self._shared_bufs['hidden_fp4']
        padded_bytes = padded_x_fp4.view(torch.uint8)
        padded_bytes.zero_()
        padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

        l1_scale_a = self._assemble_scales_cudagraph_safe(
            slot_x_sf,
            expert_offsets[:num_experts + 1],
            padded_expert_offsets,
            self._padded_x_sf_buf_l1,
            self._per_expert_scale_bufs_l1,
        )
        l1_gsa = torch.full((num_experts,), l1_gs, dtype=torch.float32, device=device)

        l1_out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._l1_mat_b,
            scale_a=l1_scale_a,
            scale_b=self._l1_scale_b,
            expert_offsets=padded_expert_offsets[1:num_experts + 1],
            global_scale_a=l1_gsa,
            global_scale_b=self._l1_gsb,
        )

        # Keep only the rows that belong to real slots.
        real_rows = l1_out[padded_dst]

        # L2 input: SiLU(gate) * up, clamped when a limit is configured.
        gate_act = torch.nn.functional.silu(real_rows[:, :inter])
        up_part = real_rows[:, inter:]
        limit = self._swiglu_limit
        if limit is not None:
            gate_act = gate_act.clamp(max=limit)
            up_part = up_part.clamp(min=-limit, max=limit)
        l2_gs = quantize_to_nvfp4(gate_act * up_part)[2]

    self._l1_activation_global_scale = l1_gs
    self._l2_activation_global_scale = l2_gs
|
||||
|
||||
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Public forward entry point; dispatches through ``nvfp4_moe_gemm``.

    On the first call the instance is registered with ``register_runner``
    and the returned id is cached on ``self._runner_id``. That id, the three
    tensors and ``self.hidden_size`` are then handed to ``nvfp4_moe_gemm``.

    The original author describes ``nvfp4_moe_gemm`` as a
    ``torch.library.custom_op`` (``nvfp4::moe_gemm``) that torch.compile
    treats as opaque and that calls ``_run_impl`` internally; neither is
    visible from this method.

    ``expert_indices`` is accepted for signature compatibility but is not
    forwarded.
    """
    already_registered = hasattr(self, '_runner_id')
    if not already_registered:
        self._runner_id = register_runner(self)
    runner_id = self._runner_id
    return nvfp4_moe_gemm(
        hidden_states,
        topk_weights,
        topk_ids,
        runner_id,
        self.hidden_size,
    )
|
||||
|
||||
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Run the NVFP4 MoE forward pass (L1 gate+up GEMM, SiLU*up, L2 down GEMM).

    Steps, in order:
      1. Remap global expert ids to local ids using ``experts_start_idx``;
         ids outside ``[0, num_experts)`` are clamped and their routing
         weight is multiplied by zero.
      2. Stable-sort the flattened (token, k) slots by local expert id.
      3. Build two offset tables: ``expert_offsets`` (cumulative real slot
         counts) and ``padded_expert_offsets`` (cumulative counts after
         rounding each expert up to a multiple of 128 rows).
      4. Quantize, scatter into the padded layout, run the grouped GEMM,
         gather the real rows back -- once for L1 and once for L2.
      5. Weight each slot's output and scatter-add it into the output rows.

    The grouped GEMMs receive the *padded* offsets; the scale assembly
    helper receives both tables.

    The original author states this path is CUDA-graph-safe (no CPU-GPU
    syncs, no dynamic shapes); that depends on helpers not visible here.

    Args:
        hidden_states: indexed on dim 0 by token; ``shape[0]`` is the
            token count.
        topk_weights: routing weights, same layout as ``topk_ids``.
        topk_ids: global expert ids; ``shape[1]`` is ``top_k``.
        expert_indices: accepted but not used in this body.

    Returns:
        ``self._output_buf[:num_tokens]`` -- a view of a buffer owned by
        ``self`` (not a fresh tensor).
    """
    num_tokens = hidden_states.shape[0]
    top_k = topk_ids.shape[1]
    # NOTE(review): `device` is assigned but never read below.
    device = hidden_states.device

    self._ensure_stacked()

    # -- Remap global expert IDs to local IDs --
    # Non-local ids are clamped into range so indexing stays valid; their
    # contribution is cancelled by the zeroed weight.
    local_ids = topk_ids - self.experts_start_idx
    local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
    safe_ids = local_ids.clamp(0, self.num_experts - 1)
    safe_weights = topk_weights * local_mask.float()

    # -- Build slot mapping --
    flat_ids = safe_ids.reshape(-1)
    flat_weights = safe_weights.reshape(-1)
    num_slots = num_tokens * top_k
    # presumably maps each flattened slot to its source token row --
    # confirm against _fill_token_indices
    token_indices = self._token_indices[:num_slots]

    # Stable sort groups slots by expert while keeping the original slot
    # order inside each expert.
    sort_idx = flat_ids.argsort(stable=True)
    sorted_ids = flat_ids[sort_idx]
    sorted_weights = flat_weights[sort_idx]
    sorted_token_ids = token_indices[sort_idx]

    # Expert offsets (real token counts): one-hot compare against the
    # expert id range, summed over slots, then cumulative-summed.
    expert_id_range = self._expert_id_range
    tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
    expert_offsets = self._expert_offsets_buf
    expert_offsets.zero_()
    expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)

    # Pad each expert to 128-row alignment (GPU-only computation)
    padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
    padded_expert_offsets = self._padded_expert_offsets_buf
    padded_expert_offsets.zero_()
    padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
    # NOTE(review): `total_padded_slots` is assigned but never read below.
    total_padded_slots = padded_expert_offsets[self.num_experts]

    # -- Gather hidden states into slot order, compute padded_dst --
    slot_hidden = hidden_states[sorted_token_ids]
    row_indices = self._row_indices_buf[:num_slots]
    # For every sorted row find the owning expert, then translate
    # "row within expert" into a row of the padded layout.
    expert_assign = torch.searchsorted(
        expert_offsets[1:], row_indices, right=True
    ).clamp(max=self.num_experts - 1)
    local_row = row_indices - expert_offsets[expert_assign]
    padded_dst = padded_expert_offsets[expert_assign] + local_row

    # === L1: gate + up ===
    # Quantize the sorted, unpadded rows (slot_hidden) rather than a padded
    # tensor: this yields exactly one scale row per slot, which the scale
    # assembly helper then places into the padded layout.
    slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
        slot_hidden, self._l1_activation_global_scale
    )
    # Scatter x_fp4 into padded layout for the GEMM
    # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
    padded_x_fp4 = self._shared_bufs['hidden_fp4']
    padded_x_fp4.view(torch.uint8).zero_()
    padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

    l1_scale_a = self._assemble_scales_cudagraph_safe(
        slot_x_sf, expert_offsets[:self.num_experts + 1],
        padded_expert_offsets,
        self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
    )
    # fill_ writes the scalar into the existing buffer and returns it.
    l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)

    l1_out = run_nvfp4_grouped_gemm(
        mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
        scale_a=l1_scale_a, scale_b=self._l1_scale_b,
        expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
        global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
    )

    # Extract real token outputs from padded GEMM output
    l1_out_real = l1_out[padded_dst]

    # === SiLU(gate) * up (with swiglu_limit clamp) ===
    # The first `intermediate_size` columns are treated as gate, the rest
    # as up.
    gate = l1_out_real[:, :self.intermediate_size]
    up = l1_out_real[:, self.intermediate_size:]
    gate_silu = torch.nn.functional.silu(gate)
    # swiglu_limit: silu(gate) is clamped from above only, `up` on both sides
    if self._swiglu_limit is not None:
        gate_silu = gate_silu.clamp(max=self._swiglu_limit)
        up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
    activated = gate_silu * up

    # === L2: down ===
    # Quantize activated (per-token), scatter into padded FP4 buffer
    slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
        activated, self._l2_activation_global_scale
    )
    padded_activated_fp4 = self._shared_bufs['activated_fp4']
    padded_activated_fp4.view(torch.uint8).zero_()
    padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)

    l2_scale_a = self._assemble_scales_cudagraph_safe(
        slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
        padded_expert_offsets,
        self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
    )
    l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)

    l2_out = run_nvfp4_grouped_gemm(
        mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
        scale_a=l2_scale_a, scale_b=self._l2_scale_b,
        expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
        global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
    )

    l2_out_real = l2_out[padded_dst]

    # === Scatter -> final output ===
    # Each slot's row is scaled by its routing weight and accumulated into
    # the row of the token it came from.
    y = self._output_buf[:num_tokens]
    y.zero_()
    weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
    y.scatter_add_(
        0,
        sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
        weighted_out,
    )

    # Refill _token_indices once, after the first call. The original author
    # notes that CuTeDSL's cute.compile may corrupt GPU memory during the
    # first GEMM JIT; this is a workaround for that.
    if self._needs_token_refill:
        self._fill_token_indices()
        self._needs_token_refill = False

    return y
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,308 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CuTeDSL NVFP4 MoE experts for DeepSeek-V4.
|
||||
|
||||
Integrates the CuTeDSL NVFP4 grouped GEMM kernel into vLLM's FusedMoE
|
||||
modular kernel framework. This is the proper integration path — no
|
||||
monkey-patching, no post-load surgery.
|
||||
|
||||
The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX.
|
||||
It handles:
|
||||
- L1 GEMM (gate + up projections)
|
||||
- SiLU activation with optional swiglu_limit clamping
|
||||
- L2 GEMM (down projection)
|
||||
- All with NVFP4 (float8_e4m3fn block scales + float32 global scales)
|
||||
|
||||
CUDA-graph-safe: all intermediate buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from cutedsl.runner import CuTeDSLMoERunner
|
||||
|
||||
|
||||
class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
|
||||
"""CuTeDSL NVFP4 MoE experts using the custom CuTeDSL grouped GEMM kernel.
|
||||
|
||||
Uses Standard activation format (non-batched). Handles input quantization
|
||||
internally (expects_unquantized_inputs=True).
|
||||
|
||||
Supports expert parallelism: remaps global→local expert IDs.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
):
    """Record the MoE geometry; the runner itself is built later.

    Args:
        moe_config: supplies dtype, hidden/intermediate sizes, top-k,
            expert counts and the expert-parallel rank.
        quant_config: must report ``quant_dtype == "nvfp4"``.

    Raises:
        AssertionError: if ``quant_config.quant_dtype`` is not ``"nvfp4"``.
    """
    super().__init__(moe_config=moe_config, quant_config=quant_config)

    quant_dtype = quant_config.quant_dtype
    assert quant_dtype == "nvfp4", (
        "CuTeDSL MoE only supports nvfp4 quantization, "
        f"got {quant_dtype}"
    )

    # Shapes and dtypes.
    self.out_dtype = moe_config.in_dtype
    self.hidden_dim = moe_config.hidden_dim
    self.intermediate_size_per_partition = (
        moe_config.intermediate_size_per_partition
    )

    # Routing / expert-parallel layout. The first global expert id owned by
    # this rank is ep_rank * num_local_experts.
    self.topk = moe_config.experts_per_token
    self.local_num_experts = moe_config.num_local_experts
    self.global_num_experts = moe_config.num_experts
    self.ep_rank = moe_config.moe_parallel_config.ep_rank
    self.local_expert_offset = self.ep_rank * self.local_num_experts

    # Token capacity for buffer pre-allocation; falls back to 8192 when the
    # config does not carry a max_num_tokens attribute.
    self.max_num_tokens = getattr(moe_config, 'max_num_tokens', 8192)

    # Both are populated by process_weights_after_loading.
    self._swiglu_limit = None
    self._runner: CuTeDSLMoERunner | None = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Build the CuTeDSL runner from the weights stored on ``layer``.

    Reads the w13/w2 weight, block-scale and global-scale tensors from the
    layer, hands them to a new ``CuTeDSLMoERunner``, and replaces the
    layer's two large weight parameters with same-shaped CPU placeholders.

    ``input_scale`` is NOT folded into ``weight_scale_2`` here: the weight
    global scales are passed to the runner unchanged, and the activation
    global scales are seeded separately from the (re-inverted) input scales.

    Args:
        layer: module exposing ``w13_weight``, ``w2_weight``,
            ``w13_weight_scale``, ``w2_weight_scale``,
            ``w13_weight_scale_2``, ``w2_weight_scale_2``,
            ``w13_input_scale``, ``w2_input_scale`` and optionally
            ``swiglu_limit``.

    Side effects:
        Sets ``self._runner`` (and possibly ``self._swiglu_limit``);
        reassigns ``layer.w13_weight`` and ``layer.w2_weight``.
    """
    num_experts = layer.w13_weight.shape[0]
    hidden_size = self.hidden_dim
    intermediate_size = self.intermediate_size_per_partition
    device = layer.w13_weight.device

    # The input scales stored on the layer are taken to be reciprocals
    # (for this backend convert_to_nvfp4_moe_kernel_format computes
    # 1.0 / a13_scale and 1.0 / a2_scale). Inverting again recovers the
    # checkpoint value, which per the original author is
    # amax / (6.0 * 448.0) and is used below as the initial activation
    # global scale. Python floats and None are skipped.
    if layer.w13_input_scale is not None and not isinstance(layer.w13_input_scale, float):
        w13_input_scale_orig = 1.0 / layer.w13_input_scale
    else:
        w13_input_scale_orig = None
    if layer.w2_input_scale is not None and not isinstance(layer.w2_input_scale, float):
        w2_input_scale_orig = 1.0 / layer.w2_input_scale
    else:
        w2_input_scale_orig = None

    # Grab the underlying tensors. Shapes in the trailing comments are the
    # original author's description of the checkpoint layout -- TODO
    # confirm against the loader.
    w13_uint8 = layer.w13_weight.data  # (E, 2*inter, hidden//2) uint8 -- gate + up fused
    w2_uint8 = layer.w2_weight.data  # (E, hidden, intermediate//2) uint8 -- down
    w13_sf = layer.w13_weight_scale.data  # (E, 2*inter, hidden//16) = (E, N, K_sf)
    w2_sf = layer.w2_weight_scale.data  # (E, hidden, intermediate//16) = (E, N, K_sf)
    w13_gs = layer.w13_weight_scale_2.data  # (E,) or (E, 2)
    w2_gs = layer.w2_weight_scale_2.data  # (E,) or (E, 2)

    # Reinterpret the packed bytes as fp4 pairs; .view() shares storage.
    l1_fp4 = w13_uint8.view(torch.float4_e2m1fn_x2)  # (E, N, K_packed)
    l2_fp4 = w2_uint8.view(torch.float4_e2m1fn_x2)  # (E, N, K_packed)

    # Cast block scales to float8_e4m3fn only when they are not already.
    if w13_sf.dtype != torch.float8_e4m3fn:
        w13_sf = w13_sf.to(torch.float8_e4m3fn)
    if w2_sf.dtype != torch.float8_e4m3fn:
        w2_sf = w2_sf.to(torch.float8_e4m3fn)

    # Global weight scales as plain Python lists.
    l1_gs_list = w13_gs.tolist()
    l2_gs_list = w2_gs.tolist()

    # Swap the layer's large weight parameters for uninitialised CPU
    # tensors of the same shape. Per the original author the surrounding
    # framework still reads w1.shape[0] from the layer, so the parameters
    # cannot simply be set to None. l1_fp4 / l2_fp4 keep the original
    # storage alive until the runner is done with them. The (small) scale
    # tensors are left on the layer untouched.
    layer.w13_weight = torch.nn.Parameter(torch.empty(
        num_experts, 2 * intermediate_size, hidden_size // 2,
        device='cpu', dtype=torch.uint8), requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(torch.empty(
        num_experts, hidden_size, intermediate_size // 2,
        device='cpu', dtype=torch.uint8), requires_grad=False)

    # Create the CuTeDSL runner
    self._runner = CuTeDSLMoERunner(
        num_experts=num_experts,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_num_tokens=self.max_num_tokens,
        top_k=self.topk,
        device=str(device),
        experts_start_idx=self.local_expert_offset,
    )
    # Hand over the stacked tensors in (E, N, K) order.
    self._runner.prepare_weights_from_stacked(
        l1_fp4, w13_sf, l1_gs_list,
        l2_fp4, w2_sf, l2_gs_list,
    )
    # Forward a limit that was already set on this object, if any.
    if self._swiglu_limit is not None:
        self._runner.set_swiglu_limit(float(self._swiglu_limit))

    # A swiglu_limit attribute on the layer takes precedence (the original
    # author notes it is set by DeepseekV4MoE).
    swiglu_limit = getattr(layer, 'swiglu_limit', None)
    if swiglu_limit is not None:
        self._swiglu_limit = swiglu_limit
        self._runner.set_swiglu_limit(float(swiglu_limit))

    # Seed the runner's activation global scales with the mean of the
    # recovered per-expert input scales; non-positive means are ignored.
    # Per the original author a later warmup step
    # (compute_activation_global_scales) overrides these values before the
    # first inference; that step is not invoked from this method.
    if w13_input_scale_orig is not None:
        mean_l1_gs = float(w13_input_scale_orig.mean().item())
        if mean_l1_gs > 0:
            self._runner._l1_activation_global_scale = mean_l1_gs
    if w2_input_scale_orig is not None:
        mean_l2_gs = float(w2_input_scale_orig.mean().item())
        if mean_l2_gs > 0:
            self._runner._l2_activation_global_scale = mean_l2_gs
|
||||
|
||||
@property
def runner(self) -> CuTeDSLMoERunner | None:
    """The underlying runner, or ``None`` until weights are processed."""
    return self._runner

@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """This implementation uses the Standard activation format."""
    return mk.FusedMoEActivationFormat.Standard

@staticmethod
def _supports_current_device() -> bool:
    """True on a CUDA platform in device capability family 100."""
    on_cuda = current_platform.is_cuda()
    if not on_cuda:
        return on_cuda
    return current_platform.is_device_capability_family(100)

@staticmethod
def _supports_no_act_and_mul() -> bool:
    """The no-act-and-mul mode is not supported."""
    return False

@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Accept only static-NVFP4 weights with dynamic-NVFP4 activations."""
    requested = (weight_key, activation_key)
    supported_pairs = (
        (kNvfp4Static, kNvfp4Dynamic),
    )
    return any(requested == pair for pair in supported_pairs)

@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
    """Only SiLU is accepted."""
    is_silu = activation == MoEActivation.SILU
    return is_silu

@staticmethod
def _supports_parallel_config(
    moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
    """Every parallel configuration is reported as supported."""
    return True

def supports_expert_map(self) -> bool:
    """Expert maps are not supported."""
    return False

@property
def expects_unquantized_inputs(self) -> bool:
    """Inputs are expected unquantized; the runner quantizes them."""
    return True

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    """Return a no-op weight-and-reduce step."""
    return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
    self,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Report workspace and output shapes for ``M`` tokens.

    Both workspaces are requested as ``(0,)``; the original author notes
    that the runner works out of its own pre-allocated buffers. The output
    shape is ``(M, self.hidden_dim)``. All other arguments are ignored.

    Returns:
        ``(workspace1_shape, workspace2_shape, output_shape)``.
    """
    no_workspace = (0,)
    output_shape = (M, self.hidden_dim)
    return (no_workspace, no_workspace, output_shape)
|
||||
|
||||
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    workspace13: torch.Tensor | None,
    workspace2: torch.Tensor | None,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool | None,
):
    """Run the expert computation and write the result into ``output``.

    Only ``hidden_states``, ``topk_weights`` and ``topk_ids`` are forwarded
    to the runner; every other argument (including ``w1``/``w2``,
    ``expert_map`` and ``apply_router_weight_on_input``) is ignored here.
    The original author's note says the runner expects global expert ids
    in ``topk_ids`` and remaps them itself via ``experts_start_idx``.

    Raises:
        AssertionError: if ``process_weights_after_loading`` has not yet
            created the runner.
    """
    runner = self._runner
    assert runner is not None, (
        "CuTeDSL runner not initialized. "
        "Call process_weights_after_loading first."
    )

    moe_out = runner.run(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
    )

    # The result is copied into the caller's buffer, not returned.
    output.copy_(moe_out)
|
||||
@@ -1,535 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
nvfp4_moe_quant_config,
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_nvfp4_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
|
||||
kE2M1ToFloat_handle,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NvFp4MoeBackend(Enum):
    """Identifiers for the NVFP4 MoE kernel backends known to this module.

    Each member's value is its own name as a string.
    """

    # FlashInfer-provided backends.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
    FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
    # Remaining backends.
    VLLM_CUTLASS = "VLLM_CUTLASS"
    MARLIN = "MARLIN"
    CUTEDSL = "CUTEDSL"
    EMULATION = "EMULATION"
|
||||
|
||||
|
||||
# Backends implemented on top of FlashInfer.
FLASHINFER_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.FLASHINFER_TRTLLM,
    NvFp4MoeBackend.FLASHINFER_CUTLASS,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
]

# Backends implemented by the standalone CuTeDSL kernel.
CUTEDSL_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.CUTEDSL,
]

# Translation from FlashInfer's own backend enum to this module's enum.
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
    FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
    FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
    FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
|
||||
|
||||
|
||||
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Tell whether ``backend`` can quantize with all experts' scale factors.

    This matters in Expert Parallel mode, when the experts are not all on
    the same rank. The answer is True exactly for the FlashInfer backends
    and the CuTeDSL backends.
    """
    if backend in FLASHINFER_NVFP4_MOE_BACKENDS:
        return True
    return backend in CUTEDSL_NVFP4_MOE_BACKENDS
|
||||
|
||||
|
||||
def backend_to_kernel_cls(
    backend: NvFp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Return the candidate expert-kernel classes for ``backend``.

    The classes are listed in order of preference. Each implementation
    module is imported only when its backend is requested.

    Raises:
        ValueError: if ``backend`` matches none of the known members.
    """
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
            TrtLlmNvFp4ExpertsModular,
            TrtLlmNvFp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so Monolithic comes first.
        return [
            TrtLlmNvFp4ExpertsMonolithic,
            TrtLlmNvFp4ExpertsModular,
        ]

    if backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import (  # noqa: E501
            FlashInferExperts,
        )

        return [FlashInferExperts]

    if backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import (  # noqa: E501
            FlashInferCuteDSLExperts,
        )

        return [FlashInferCuteDSLExperts]

    if backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import (  # noqa: E501
            FlashInferCuteDSLBatchedExperts,
        )

        return [FlashInferCuteDSLBatchedExperts]

    if backend == NvFp4MoeBackend.CUTEDSL:
        from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import (  # noqa: E501
            CuTeDSLMoEExperts,
        )

        return [CuTeDSLMoEExperts]

    if backend == NvFp4MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
            CutlassExpertsFp4,
        )

        return [CutlassExpertsFp4]

    if backend == NvFp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.experts.marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    if backend == NvFp4MoeBackend.EMULATION:
        from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import (
            Nvfp4QuantizationEmulationTritonExperts,
        )

        return [Nvfp4QuantizationEmulationTritonExperts]

    raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
    """Translate the user-facing ``moe_backend`` string into an enum member.

    Raises:
        ValueError: if ``runner_backend`` is not one of the supported keys;
            the message lists the accepted values.
    """
    mapping = {
        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        "cutedsl": NvFp4MoeBackend.CUTEDSL,
        "marlin": NvFp4MoeBackend.MARLIN,
        "emulation": NvFp4MoeBackend.EMULATION,
    }
    resolved = mapping.get(runner_backend)
    if not resolved:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
            f"Expected one of {list(mapping.keys())}."
        )
    return resolved
|
||||
|
||||
|
||||
def select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
    """Select the primary NvFP4 MoE backend and its kernel class.

    Resolution order:
      1. An explicit ``config.moe_backend`` (anything but ``"auto"``) is
         validated and returned, or a ``ValueError`` is raised.
      2. If ``VLLM_USE_FLASHINFER_MOE_FP4`` is set: a false value removes
         the FlashInfer backends from the candidates; a true value either
         validates the backend named by ``VLLM_FLASHINFER_MOE_BACKEND`` or
         tries each FlashInfer backend in turn.
      3. ``VLLM_TEST_FORCE_FP8_MARLIN`` forces MARLIN.
      4. Otherwise the first supported backend in the candidate list wins.

    Note: Shape-specific fallbacks may still occur at runtime.

    Raises:
        ValueError: an explicitly requested backend is unsupported.
        NotImplementedError: no candidate backend supports the config.
    """
    # NOTE: the kernels are selected in the following order.
    AVAILABLE_BACKENDS = [
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
        NvFp4MoeBackend.CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.MARLIN,
        NvFp4MoeBackend.EMULATION,
    ]

    if config.moe_parallel_config.use_batched_activation_format:
        activation_format = mk.FusedMoEActivationFormat.BatchedExperts
    else:
        activation_format = mk.FusedMoEActivationFormat.Standard

    def _make_log_backend(backend: NvFp4MoeBackend):
        """Message announcing the chosen backend and the candidate list."""
        available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
        return (
            f"Using '{backend.value}' NvFp4 MoE backend out "
            f"of potential backends: {available_backend_strs}."
        )

    def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
        """Message explaining that ``backend`` was rejected."""
        if not reason:
            return (
                f"NvFp4 MoE backend '{backend.value}' does not support the "
                "deployment configuration."
            )
        return (
            f"NvFp4 MoE backend '{backend.value}' does not support the "
            f"deployment configuration since {reason}."
        )

    def _find_kernel(backend: NvFp4MoeBackend, log_unsupported: bool):
        """Return ``(kernel class or None, last rejection reason)``."""
        reason = None
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return k_cls, reason
            if log_unsupported:
                logger.debug_once(_make_log_unsupported(backend, reason))
        return None, reason

    def _return_or_raise(
        backend: NvFp4MoeBackend,
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
        """Validate one mandatory backend; raise ValueError if unsupported."""
        k_cls, reason = _find_kernel(backend, log_unsupported=False)
        if k_cls is None:
            raise ValueError(_make_log_unsupported(backend, reason))
        return backend, k_cls

    def _first_supported(candidates):
        """Return the first supported ``(backend, kernel class)`` or None."""
        for backend in candidates:
            k_cls, _ = _find_kernel(backend, log_unsupported=True)
            if k_cls is not None:
                return backend, k_cls
        return None

    # Handle explicit moe_backend from user.
    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_nvfp4_backend(runner_backend)
        # For batched activation format, use batched variant if available.
        wants_batched = (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
        )
        if wants_batched and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
            requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
        return _return_or_raise(requested_backend)

    if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
        if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
            # If the user rejects FlashInfer remove those backends.
            for b in FLASHINFER_NVFP4_MOE_BACKENDS:
                AVAILABLE_BACKENDS.remove(b)

        elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If user is explicit about backend, validate it.
            return _return_or_raise(
                fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
            )
        else:
            # If the user is not explicit about the backend, try each.
            chosen = _first_supported(FLASHINFER_NVFP4_MOE_BACKENDS)
            if chosen is not None:
                return chosen

            raise NotImplementedError(
                "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
                "FlashInfer NVFP4 MoE backend supports the configuration."
            )

    if envs.VLLM_TEST_FORCE_FP8_MARLIN:
        return _return_or_raise(NvFp4MoeBackend.MARLIN)

    # Select kernels in order of backend.
    chosen = _first_supported(AVAILABLE_BACKENDS)
    if chosen is not None:
        return chosen

    raise NotImplementedError(
        "No NvFp4 MoE backend supports the deployment configuration."
    )
|
||||
|
||||
|
||||
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Convert NVFP4 MoE weights and scales into the layout a backend expects.

    Dispatches on ``nvfp4_backend``:

    * ``CUTEDSL``: weights pass through untouched; activation scales that
      are not ``None`` are replaced by their reciprocal.
    * ``FLASHINFER_CUTEDSL``: delegates to
      ``prepare_nvfp4_moe_layer_for_flashinfer_cutedsl``.
    * any other member of ``FLASHINFER_NVFP4_MOE_BACKENDS`` or
      ``VLLM_CUTLASS``: delegates to
      ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``.
    * ``MARLIN``: delegates to ``prepare_nvfp4_moe_layer_for_marlin``; both
      activation scales are returned as ``None``.
    * ``EMULATION``: weights pass through; each activation scale collapses
      to the reciprocal of its maximum, as float32.

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale,
        w2_scale_2, a2_scale)``.

    Raises:
        ValueError: for an unrecognised backend, or when ``EMULATION`` is
            selected and either activation scale is ``None``.
    """
    # Keyword arguments shared by every prepare_* helper below.
    weight_kwargs = dict(
        layer=layer,
        w13=w13,
        w13_scale=w13_scale,
        w13_scale_2=w13_scale_2,
        w2=w2,
        w2_scale=w2_scale,
        w2_scale_2=w2_scale_2,
    )

    if nvfp4_backend == NvFp4MoeBackend.CUTEDSL:
        # The CuTeDSL kernel converts weights itself in
        # process_weights_after_loading, so only the activation scales
        # are inverted here for the quant config.
        a13_scale = None if a13_scale is None else 1.0 / a13_scale
        a2_scale = None if a2_scale is None else 1.0 / a2_scale
    elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        prepared = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
            a13_scale=a13_scale,
            a2_scale=a2_scale,
            **weight_kwargs,
        )
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepared
    elif (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    ):
        prepared = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            a13_scale=a13_scale,
            a2_scale=a2_scale,
            is_act_and_mul=is_act_and_mul,
            **weight_kwargs,
        )
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepared
    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # The Marlin helper neither takes nor returns activation scales.
        a13_scale = a2_scale = None
        prepared = prepare_nvfp4_moe_layer_for_marlin(
            is_act_and_mul=is_act_and_mul,
            **weight_kwargs,
        )
        (
            w13,
            w13_scale,
            w13_scale_2,
            w2,
            w2_scale,
            w2_scale_2,
        ) = prepared
    elif nvfp4_backend == NvFp4MoeBackend.EMULATION:
        # Move the E2M1 lookup table to the device now, because
        # `.to(device)` is not allowed during CUDA graph capture.
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device)

        if a13_scale is None or a2_scale is None:
            raise ValueError(
                "Activation global scales should not be None, got"
                f" a13_scale={a13_scale}, a2_scale={a2_scale}"
            )

        scales_are_uniform = (
            torch.unique(a13_scale).numel() == 1
            and torch.unique(a2_scale).numel() == 1
        )
        if not scales_are_uniform:
            logger.warning_once(
                "In NVFP4 linear, the activation global scale for inputs are different"
                " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
                " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
            )

        # 1. The max is taken following e.g.
        #    quantization/utils/flashinfer_fp4_moe.py.
        # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant use the
        #    inverse scale directly (large global scale).
        # NOTE: before this point `a13_scale` and `a2_scale` satisfy
        # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`
        # and `global_scale[expert_id]` is small (~1e-4). Taking the largest
        # global scale likely overflows the FP8 range for other experts;
        # other selection strategies may be used.
        a13_scale = 1.0 / a13_scale.max().to(torch.float32)
        a2_scale = 1.0 / a2_scale.max().to(torch.float32)
    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
|
||||
|
||||
|
||||
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the FusedMoEQuantConfig for the selected NVFP4 MoE backend.

    * ``MARLIN`` gets a weight-only (w4a16) config; activation scales are
      not used.
    * ``EMULATION`` receives ``a13_scale`` / ``a2_scale`` unchanged.
    * Every other backend receives the reciprocals of the activation
      scales and an ``is_scale_swizzled`` flag.
    """
    # w13_scale_2 / w2_scale_2 are passed directly as g1/g2_alphas.
    # The expert's process_weights_after_loading will fuse activation
    # scales in-place. Since the quant config references the same tensor
    # as the registered parameter, EPLB rearrangement stays in sync.
    weight_scale_kwargs = dict(
        g1_alphas=w13_scale_2,
        g2_alphas=w2_scale_2,
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )

    if backend == NvFp4MoeBackend.MARLIN:
        return nvfp4_w4a16_moe_quant_config(**weight_scale_kwargs)

    if backend == NvFp4MoeBackend.EMULATION:
        return nvfp4_moe_quant_config(
            a1_gscale=a13_scale,
            a2_gscale=a2_scale,
            **weight_scale_kwargs,
        )

    # NOTE(rob): this is a hack until the MoE kernels create their own
    # quant configs. TRTLLM kernel does not accept swizzled input quant
    # scales.
    unswizzled_backends = (
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.CUTEDSL,
    )
    return nvfp4_moe_quant_config(
        a1_gscale=(1.0 / a13_scale),
        a2_gscale=(1.0 / a2_scale),
        is_scale_swizzled=(backend not in unswizzled_backends),
        **weight_scale_kwargs,
    )
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel from a prepare/finalize stage and experts.

    The prepare/finalize object comes from ``maybe_make_prepare_finalize``
    and must not be ``None``. When its activation format is
    ``BatchedExperts`` the experts are additionally constructed with
    ``max_num_tokens`` and ``num_dispatchers``.
    """
    # Stage 1: prepare/finalize.
    is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=is_monolithic,
    )
    assert prepare_finalize is not None

    logger.info_once("Using %s", prepare_finalize.__class__.__name__)

    # Stage 2: experts. Batched activation formats need two extra arguments.
    experts_kwargs = dict(moe_config=moe_config, quant_config=moe_quant_config)
    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format == batched_format:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts_kwargs["max_num_tokens"] = max_num_tokens
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # TODO(rob): update inplace logic to be part of the kernel.
    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        inplace=False,
    )
|
||||
@@ -1,216 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict, fields
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config, get_hash_factors, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
class IrOpPriorityConfig:
    """
    Configuration for vLLM IR op priority for dispatching/lowering during the
    forward pass. Each member is a list of strings, which will be passed to
    vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
    A single comma-separated string is accepted as well.

    If specified manually, platform defaults will be appended to the lists.
    See KernelConfig.set_platform_defaults().
    """

    rms_norm: list[str] = Field(default_factory=list)
    """Priority list for vllm.ir.ops.rms_norm"""

    fused_add_rms_norm: list[str] = Field(default_factory=list)
    """Priority list for vllm.ir.ops.fused_add_rms_norm"""

    def compute_hash(self) -> str:
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.

        Also, manually add IR op impl UUIDs to make sure they affect the compile cache.
        """
        factors = get_hash_factors(self, set())

        # Implementations are hidden from Dynamo,
        # so they don't show up in the traced files list.
        from vllm.ir.op import IrOp

        assert "_impls" not in factors
        # Maps op name -> {provider name -> implementation UUID}.
        factors["_impls"] = {
            name: {
                provider: IrOp.registry[name].impls[provider].uuid() for provider in p
            }
            for name, p in asdict(self).items()  # type: ignore[call-overload]
        }

        return hash_factors(factors)

    @field_validator("*", mode="before")
    @classmethod
    def _to_list_str(cls, value: str | list[str]):
        """Normalize a comma-separated string into a list of provider names.

        Spaces are stripped before splitting. Empty entries are dropped, so
        an empty string yields ``[]`` and a trailing comma does not produce
        an empty provider name (``"".split(",")`` would otherwise be
        ``[""]``). Lists are returned unchanged.
        """
        if isinstance(value, str):
            value = [item for item in value.replace(" ", "").split(",") if item]

        assert all(isinstance(v, str) for v in value)
        return value

    @contextlib.contextmanager
    def set_priority(self):
        """
        Context manager to set the IR op priority for all op members.
        It also imports IR kernel implementations for the current platform
        to ensure all implementations are made available.
        """
        from vllm.ir.op import IrOp
        from vllm.platforms import current_platform

        current_platform.import_ir_kernels()

        # One context per field is entered on the stack, so all of them are
        # unwound together when the caller's `with` block exits.
        with contextlib.ExitStack() as stack:
            for field in fields(self):  # type: ignore[arg-type]
                op_priority = getattr(self, field.name)
                assert op_priority is not None, (
                    f"IR op priority for {field.name} must be set"
                )
                logger.debug(
                    "Setting IR op priority for %s to %s", field.name, op_priority
                )
                ir_op = IrOp.registry[field.name]
                stack.enter_context(ir_op.set_priority(op_priority))

            yield

    @classmethod
    def with_default(
        cls, default: list[str], /, **kwargs: list[str]
    ) -> "IrOpPriorityConfig":
        """
        A helper to create an IrOpPriorityConfig where fields not specified in kwargs
        use the given default list.
        """
        for field in fields(cls):  # type: ignore[arg-type]
            if field.name not in kwargs:
                # Copy so that fields do not share one mutable list.
                kwargs[field.name] = list(default)

        return cls(**kwargs)
|
||||
|
||||
|
||||
# Accepted values for KernelConfig.moe_backend.
MoEBackend = Literal[
    "auto", "triton", "deep_gemm", "deep_gemm_mega_moe", "cutlass",
    "flashinfer_trtllm", "flashinfer_cutlass", "flashinfer_cutedsl",
    "marlin", "humming", "triton_unfused", "aiter", "cutedsl", "emulation",
]
|
||||
|
||||
|
||||
@config
class KernelConfig:
    """Configuration for kernel selection and warmup behavior."""

    ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig)
    """
    vLLM IR op priority for dispatching/lowering during the forward pass.
    Platform defaults appended automatically during VllmConfig.__post_init__.
    """

    enable_flashinfer_autotune: bool = None  # type: ignore[assignment]
    """If True, run FlashInfer autotuning during kernel warmup."""

    moe_backend: MoEBackend = "auto"
    """Backend for MoE expert computation kernels. Available options:

    - "auto": Automatically select the best backend based on model and hardware
    - "triton": Use Triton-based fused MoE kernels
    - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)
    - "deep_gemm_mega_moe": Use DeepGEMM mega MoE kernels
    - "cutlass": Use vLLM CUTLASS kernels
    - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
    - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
    - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
    - "marlin": Use Marlin kernels (weight-only quantization)
    - "humming": Use Humming Mixed Precision kernels
    - "triton_unfused": Use Triton unfused MoE kernels
    - "aiter": Use AMD AITer kernels (ROCm only)
    - "cutedsl": Use CuTeDSL kernels
    - "emulation": use BF16/FP16 GEMM, dequantizing weights and
        running QDQ on activations.
    """

    @field_validator("moe_backend", mode="before")
    @classmethod
    def _normalize_moe_backend(cls, value: Any) -> Any:
        """Lower-case string values and map ``-`` to ``_`` before validation."""
        if isinstance(value, str):
            return value.lower().replace("-", "_")
        return value

    def compute_hash(self) -> str:
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.
        """
        ignored_factors = {
            "enable_flashinfer_autotune",
            "ir_op_priority",  # handled separately below
        }
        factors = get_hash_factors(self, ignored_factors)
        factors["ir_op_priority"] = self.ir_op_priority.compute_hash()
        return hash_factors(factors)

    @field_validator("enable_flashinfer_autotune", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialization is delayed."""
        if value is None:
            return value
        return handler(value)

    def set_platform_defaults(self, vllm_config: "VllmConfig") -> None:
        """Set platform-specific defaults for the kernel config.

        For every op in the platform's default IR op priority, providers
        that are not already listed in ``self.ir_op_priority`` are appended
        after the existing entries.
        """
        from vllm.platforms import current_platform

        platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config)
        logger.debug(
            "Setting platform-specific IR op priority defaults: %s, user-defined: %s",
            platform_op_priority,
            self.ir_op_priority,
        )
        for op_name, op_priority in asdict(platform_op_priority).items():
            current_op_priority: list[str] = getattr(self.ir_op_priority, op_name)
            if current_op_priority is None:
                setattr(self.ir_op_priority, op_name, op_priority)
            else:
                # Append platform-specific priorities
                # Must be idempotent because vllm_config.set_platform_defaults() may be
                # called multiple times (due to VllmConfig.__post_init__ manual call).
                unique_op_priority = [
                    op for op in op_priority if op not in current_op_priority
                ]
                current_op_priority.extend(unique_op_priority)

        logger.info(
            "Final IR op priority after setting platform defaults: %s",
            self.ir_op_priority,
        )
|
||||
@@ -1,356 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
CSA/HCA attention for Blackwell (SM100+).
|
||||
|
||||
Replaces vLLM's FlashMLA + fused CUDA kernels with our own KV cache-based
|
||||
attention pipeline. The previous version used raw KV for attention (no cache),
|
||||
which produced garbage during decode because the KV cache was never written.
|
||||
|
||||
Key changes from the broken version:
|
||||
1. blackwell_attention_kv_write NOW WRITES KV to the paged cache (fp8); fused_qnorm_rope_kv_insert_py only applies RoPE to Q
|
||||
2. full_sdpa_attention is replaced with cache-aware attention
|
||||
3. KV is quantized to fp8 with per-token scale, RoPE applied before caching
|
||||
|
||||
Architecture:
|
||||
- KV latent: (T, HD=512) single head, shared across 128 Q heads
|
||||
- KV Cache: fp8_e4m3 paged cache with per-token inverse scale
|
||||
- Attention: BF16 (NVFP4 too lossy for Q×K^T)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def apply_gptj_rope(x, cos, sin, nope_dim):
    """Rotate the interleaved (even, odd) pairs of ``x[..., nope_dim:]``.

    Elements before ``nope_dim`` on the last axis are copied unchanged.
    ``cos`` / ``sin`` must broadcast against the even/odd halves of the
    rotated slice. ``x`` itself is not modified; a new tensor is returned.
    """
    rotated = x.clone()
    src = x[..., nope_dim:]
    dst = rotated[..., nope_dim:]  # view: writes land in `rotated`
    x_even, x_odd = src[..., 0::2], src[..., 1::2]
    dst[..., 0::2] = x_even * cos - x_odd * sin
    dst[..., 1::2] = x_even * sin + x_odd * cos
    return rotated
|
||||
|
||||
|
||||
def apply_inv_gptj_rope(x, cos, sin, nope_dim):
    """Apply the inverse interleaved-pair rotation to ``x[..., nope_dim:]``.

    Uses the same (cos, sin) as the forward rotation but with the sign of
    the sine terms flipped. Elements before ``nope_dim`` on the last axis
    are copied unchanged and ``x`` is not modified.
    """
    unrotated = x.clone()
    src = x[..., nope_dim:]
    dst = unrotated[..., nope_dim:]  # view: writes land in `unrotated`
    x_even, x_odd = src[..., 0::2], src[..., 1::2]
    dst[..., 0::2] = x_even * cos + x_odd * sin
    dst[..., 1::2] = -x_even * sin + x_odd * cos
    return unrotated
|
||||
|
||||
|
||||
# ── KV Cache Operations ──────────────────────────────────────────────
|
||||
|
||||
def kv_quantize_fp8(kv_bf16):
|
||||
"""BF16 KV → fp8_e4m3 with per-token inverse scale."""
|
||||
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
|
||||
scale = fp8_max / amax
|
||||
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
|
||||
inv_scale = (amax / fp8_max).to(torch.bfloat16)
|
||||
return kv_fp8, inv_scale
|
||||
|
||||
|
||||
def kv_dequantize_fp8(kv_fp8, inv_scale):
|
||||
"""fp8 KV → BF16."""
|
||||
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
|
||||
|
||||
|
||||
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
    """Write KV rows into a paged cache, in place.

    Args:
        kv_data: (T, D) tensor to write (fp8 or bf16).
        slot_mapping: (T,) flat slot index per row; slot ``s`` maps to
            ``cache[s // block_size, s % block_size]``.
        cache: (num_blocks, block_size, D) cache tensor (may be uint8).
        block_size: slots per block used to decompose ``slot_mapping``.

    Rows whose slot is negative or falls outside the cache are skipped.
    Without the explicit ``>= 0`` check a negative slot would pass the
    upper-bound test and, through negative indexing, overwrite an entry at
    the end of the cache.
    # NOTE(review): negative slots presumably mark padding tokens; confirm
    # against the caller.
    """
    # Handle dtype mismatch: cache is uint8, kv_data is fp8 -> reinterpret
    # the bytes instead of converting values.
    if cache.dtype == torch.uint8 and kv_data.dtype == torch.float8_e4m3fn:
        kv_to_write = kv_data.view(torch.uint8)
    else:
        kv_to_write = kv_data

    block_indices = slot_mapping // block_size
    offsets = slot_mapping % block_size

    valid = (
        (slot_mapping >= 0)
        & (block_indices < cache.shape[0])
        & (offsets < cache.shape[1])
    )
    if valid.all():
        # Common case: one vectorized scatter, no filtering copies.
        cache[block_indices, offsets] = kv_to_write
    else:
        # Partial write: keep only the valid rows, still vectorized.
        cache[block_indices[valid], offsets[valid]] = kv_to_write[valid]
|
||||
|
||||
|
||||
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
|
||||
"""Read KV from paged cache. Returns fp8 or uint8.
|
||||
|
||||
Vectorized version — uses advanced indexing instead of Python for loop.
|
||||
"""
|
||||
device = cache.device
|
||||
# Compute block indices and offsets
|
||||
slots = slot_mapping # (num_tokens,)
|
||||
block_indices = slots // block_size
|
||||
offsets = slots % block_size
|
||||
|
||||
# Advanced indexing: cache[block_indices, offsets] -> (num_tokens, head_dim)
|
||||
kv = cache[block_indices, offsets]
|
||||
|
||||
# If cache is uint8, reinterpret as fp8
|
||||
if cache.dtype == torch.uint8:
|
||||
kv = kv.view(torch.float8_e4m3fn)
|
||||
return kv
|
||||
|
||||
|
||||
# ── Attention ─────────────────────────────────────────────────────────
|
||||
|
||||
def causal_prefill_attention(q, kv, scale):
|
||||
"""Full causal self-attention for prefill. q: (T, NH, HD), kv: (T, HD)."""
|
||||
T, NH, HD = q.shape
|
||||
q_t = q.permute(1, 0, 2) # (NH, T, HD)
|
||||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T, HD)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
|
||||
return out.permute(1, 0, 2) # (T, NH, HD)
|
||||
|
||||
|
||||
def decode_attention(q, kv, scale):
|
||||
"""Decode attention: 1 query vs N cached KVs.
|
||||
|
||||
q: (1, NH, HD) — single decode token
|
||||
kv: (N, HD) — all cached KV (already with RoPE)
|
||||
"""
|
||||
NH = q.shape[1]
|
||||
HD = q.shape[2]
|
||||
q_t = q.permute(1, 0, 2) # (NH, 1, HD)
|
||||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, N, HD)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
|
||||
return out.permute(1, 0, 2) # (1, NH, HD)
|
||||
|
||||
|
||||
# ── Fused Q norm + RoPE + KV cache write ─────────────────────────────
|
||||
|
||||
def fused_qnorm_rope_kv_insert_py(
    q,  # (T, num_heads, head_dim) — modified in-place
    kv,  # (T, head_dim) — not modified
    swa_kv_cache_2d,  # paged KV cache (2D view)
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
    nope_dim,
    rope_dim,
) -> None:
    """Pure PyTorch step that applies interleaved-pair RoPE to Q in place.

    Only ``q[:, :, nope_dim:]`` is rewritten. For each token the cosines
    are the first ``rope_dim // 2`` entries of
    ``cos_sin_cache[positions]`` and the sines are the next
    ``rope_dim // 2`` entries.

    No normalization and no KV-cache write happen here: ``kv``,
    ``swa_kv_cache_2d``, ``slot_mapping``, ``eps`` and ``block_size`` are
    accepted but unused. Per the original note, Q arrives already normed
    and the cache insert is done by ``blackwell_attention_kv_write``.
    """
    if q.shape[0] == 0:
        return

    half = rope_dim // 2
    cos_q = cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype)
    sin_q = cos_sin_cache[positions, half:2 * half].unsqueeze(1).to(q.dtype)

    rope_view = q[:, :, nope_dim:]  # view: writes land in `q`
    # Snapshot first so the odd lanes are computed from pre-rotation values.
    snapshot = rope_view.clone()
    even, odd = snapshot[:, :, 0::2], snapshot[:, :, 1::2]
    rope_view[:, :, 0::2] = even * cos_q - odd * sin_q
    rope_view[:, :, 1::2] = even * sin_q + odd * cos_q
|
||||
|
||||
|
||||
def blackwell_attention_kv_write(
    kv,  # (T, HD) kv_normed — NOT RoPE'd yet
    positions,  # (T,) absolute positions
    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 paged cache
    swa_inv_scale,  # (max_slots, 1) per-token inv scale
    slot_mapping,  # (T,) slot indices
    block_size,  # tokens per block
    cos_sin_cache,  # (max_pos, rope_dim) for RoPE
    nope_dim,  # 448
    rope_dim,  # 64
) -> None:
    """Write KV to the paged cache: apply RoPE -> fp8 quantize -> insert.

    Steps:
      1. Rotate the interleaved pairs of ``kv[:, nope_dim:]`` using the
         cos/sin rows selected by ``positions`` (``kv`` is not modified).
      2. Quantize with ``kv_quantize_fp8``.
      3. Store the fp8 rows with ``paged_kv_write`` and the per-token
         inverse scales in ``swa_inv_scale`` at the same slots.

    The scale store is a single indexed assignment rather than a per-token
    ``.item()`` loop. Slots that are negative or beyond ``swa_inv_scale``
    are skipped, so such a slot can no longer overwrite the last scale
    entry through negative indexing or raise an IndexError.
    """
    if kv.shape[0] == 0:
        return

    # GPT-J style RoPE on the rotary slice of KV.
    half = rope_dim // 2
    cos = cos_sin_cache[positions, :half].to(kv.dtype)
    sin = cos_sin_cache[positions, half:2 * half].to(kv.dtype)

    # Read from the untouched input and write into the clone, so even and
    # odd lanes are both computed from pre-rotation values.
    kv_rope = kv.clone()
    rope_src = kv[:, nope_dim:]
    rope_dst = kv_rope[:, nope_dim:]  # view into kv_rope
    even = rope_src[:, 0::2]
    odd = rope_src[:, 1::2]
    rope_dst[:, 0::2] = even * cos - odd * sin
    rope_dst[:, 1::2] = even * sin + odd * cos

    # Quantize to fp8 and write the values to the paged cache.
    kv_fp8, inv_scale = kv_quantize_fp8(kv_rope)
    paged_kv_write(kv_fp8, slot_mapping, swa_kv_cache, block_size)

    # Write the matching inverse scales to the flat scale cache.
    in_range = (slot_mapping >= 0) & (slot_mapping < swa_inv_scale.shape[0])
    scales = inv_scale[in_range].to(swa_inv_scale.dtype)
    # Match the trailing shape of the destination ((K, 1) or (K,)).
    swa_inv_scale[slot_mapping[in_range]] = scales.reshape(
        -1, *swa_inv_scale.shape[1:]
    )
|
||||
|
||||
|
||||
def blackwell_attention_decode(
    q,  # (1, NH, HD) single decode query with RoPE
    positions,  # (1,) absolute position
    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 SWA cache (uint8)
    swa_inv_scale,  # (max_slots, 1) per-token inv scale
    slot_mapping,  # (1,) slot for the new token (already written)
    block_size,  # tokens per block
    scale,  # 1/sqrt(HD)
    window_size,  # 128
    swa_indices=None,  # (num_decode_tokens, window_size) pre-computed paged indices
    swa_lens=None,  # (num_decode_tokens,) number of valid indices per token
    decode_token_idx=0,  # which decode token this is
) -> torch.Tensor:
    """Decode attention: read cached KV from the paged cache and attend.

    Slot selection:
      * With ``swa_indices`` and ``swa_lens``: the first
        ``swa_lens[decode_token_idx]`` entries of
        ``swa_indices[decode_token_idx]`` are used as flat slot indices.
      * Otherwise (fallback): slots ``max(0, pos - window_size + 1) .. pos``
        with ``pos = positions[0]``. Only the window is read and
        dequantized, not the whole history.
        # NOTE(review): the fallback treats positions as slot indices,
        # i.e. it assumes one contiguously laid-out sequence; confirm.

    ``slot_mapping`` is accepted but not used.

    Returns:
        (1, NH, HD) attention output.
    """
    NH = q.shape[1]
    HD = q.shape[2]

    if swa_indices is not None and swa_lens is not None:
        # Use pre-computed paged indices from the attention metadata.
        num_valid = swa_lens[decode_token_idx].item()
        slots = swa_indices[decode_token_idx, :num_valid]
    else:
        # Fallback: sequential slot access, restricted to the window.
        pos = positions[0].item()
        window_start = max(0, pos - window_size + 1)
        slots = torch.arange(
            window_start, pos + 1, dtype=torch.int64, device=q.device
        )

    # Both paths share the same gather + per-token dequantization.
    kv_cached_raw = paged_kv_read(slots, swa_kv_cache, block_size, slots.shape[0], HD)
    kv_cached = kv_dequantize_fp8(kv_cached_raw, swa_inv_scale[slots])

    q_t = q.permute(1, 0, 2)  # (NH, 1, HD)
    kv_exp = kv_cached.unsqueeze(0).expand(NH, -1, -1)  # shared by all heads
    out = F.scaled_dot_product_attention(
        q_t, kv_exp, kv_exp, is_causal=False, scale=scale
    )
    return out.permute(1, 0, 2)
|
||||
|
||||
|
||||
def full_sdpa_attention(
    q: torch.Tensor,  # (T, NH, HD) with RoPE
    kv: torch.Tensor,  # (T, HD) KV latent
    scale: float,
) -> torch.Tensor:
    """Deprecated alias of ``causal_prefill_attention`` (prefill only).

    Retained so the existing vLLM patch that calls this name keeps working;
    new code should call ``causal_prefill_attention`` directly.
    """
    return causal_prefill_attention(q=q, kv=kv, scale=scale)
|
||||
|
||||
|
||||
# ── CSA/HCA Decode Attention ─────────────────────────────────────────
|
||||
|
||||
def blackwell_csa_decode_attention(
    q,  # (num_decode_tokens, NH, HD) with RoPE
    positions,  # (num_decode_tokens,)
    swa_kv_cache,  # (num_blocks, block_size, D) fp8 SWA cache
    swa_inv_scale,  # (max_slots, 1) per-token inv scale
    swa_metadata,  # DeepseekSparseSWAMetadata
    flashmla_metadata,  # FlashMLASparseMetadata (for topk_indices)
    compressed_kv_cache,  # (num_blocks, block_size, D) compressed KV cache
    compress_ratio,  # 4 or 128
    scale,  # 1/sqrt(HD)
    window_size,  # 128
    nope_dim,  # 448
    rope_dim,  # 64
    cos_sin_cache,  # (max_pos, rope_dim)
    attn_sink,  # (NH,) sink weights
    max_model_len,  # max sequence length
) -> torch.Tensor:
    """CSA/HCA decode entry point; currently sliding-window (SWA) only.

    Each decode token is attended independently through
    ``blackwell_attention_decode`` against the SWA cache. The sparse
    attention over ``compressed_kv_cache`` and the sink merge are NOT
    implemented yet, so ``flashmla_metadata``, ``compressed_kv_cache``,
    ``compress_ratio``, ``nope_dim``, ``rope_dim``, ``cos_sin_cache``,
    ``attn_sink`` and ``max_model_len`` are accepted but unused.

    Only ``block_size`` and ``slot_mapping`` are read from
    ``swa_metadata``.

    Returns:
        (num_decode_tokens, NH, HD) bfloat16 attention output.
    """
    num_tokens, NH, HD = q.shape
    block_size = swa_metadata.block_size

    output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=q.device)

    # TODO: add the sparse (top-k over compressed KV) component and the
    # sink merge once the SWA path is verified.
    # NOTE(review): no swa_indices/swa_lens are forwarded, so
    # blackwell_attention_decode always takes its position-based fallback.
    for t in range(num_tokens):
        token = slice(t, t + 1)
        output[t] = blackwell_attention_decode(
            q[token],
            positions[token],
            swa_kv_cache,
            swa_inv_scale,
            swa_metadata.slot_mapping[token],
            block_size,
            scale,
            window_size,
        ).squeeze(0)

    return output
|
||||
|
||||
|
||||
def csa_sparse_prefill_attention(
    q,  # (num_prefills, NH, HD) with RoPE
    kv_rope,  # (num_prefills, HD) KV with RoPE
    compressed_kv_cache,  # compressed KV cache
    flashmla_metadata,  # FlashMLASparseMetadata
    swa_metadata,  # DeepseekSparseSWAMetadata
    compress_ratio,  # 4 or 128
    scale,  # 1/sqrt(HD)
    window_size,  # 128
    nope_dim,  # 448
    rope_dim,  # 64
    cos_sin_cache,  # (max_pos, rope_dim)
    attn_sink,  # (NH,) sink weights
    max_model_len,  # max sequence length
) -> torch.Tensor:
    """CSA/HCA prefill entry point; currently plain causal attention.

    Delegates to ``causal_prefill_attention(q, kv_rope, scale)``. Every
    other argument is accepted for interface compatibility and ignored.
    Per the original note this fallback is correct but not optimal for
    long sequences.
    """
    dense_output = causal_prefill_attention(q, kv_rope, scale)
    return dense_output
|
||||
@@ -1,455 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Check at module load time if we're on Blackwell
|
||||
# Decide once, at import time, whether the device reports compute
# capability major >= 10. Any failure while querying the platform is
# deliberately treated as "not Blackwell" (best-effort probe).
try:
    _cap = current_platform.get_device_capability()
    _IS_BLACKWELL = _cap is not None and _cap.major >= 10
except Exception:
    _IS_BLACKWELL = False
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
|
||||
_fused_kv_compress_norm_rope_insert_indexer_attn,
|
||||
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
|
||||
_fused_kv_compress_norm_rope_insert_sparse_attn,
|
||||
)
|
||||
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
|
||||
MXFP4_BLOCK_SIZE,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
KVCacheSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowMLASpec,
|
||||
)
|
||||
|
||||
|
||||
class CompressorBackend(AttentionBackend):
    """Attention backend descriptor for the compressor's KV cache.

    Declares the cache layout (one KV head, ``(num_blocks, block_size,
    head_size)``), the supported head sizes and the metadata builder class.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def get_name() -> str:
        """Return the backend's registered name."""
        return "CompressorBackend"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Any block size is accepted (multiples of 1)."""
        any_block_size = MultipleOf(1)
        return [any_block_size]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Head sizes this backend supports."""
        return [512, 1024]

    @staticmethod
    def get_builder_cls() -> type["CompressorMetadataBuilder"]:
        """Return the metadata builder paired with this backend."""
        return CompressorMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Shape of the cache tensor; requires exactly one KV head.

        ``cache_dtype_str`` does not influence the shape.
        """
        assert num_kv_heads == 1
        shape = (num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Identity stride order: 4 axes with the layers dimension, else 3."""
        num_axes = 4 if include_num_layers_dimension else 3
        return tuple(range(num_axes))
|
||||
|
||||
|
||||
@dataclass
class CompressorMetadata:
    """Per-step metadata handed to the compressor state-cache kernels."""

    # Block table for the state cache (the builder passes the common
    # attention metadata's block table, clamped to be non-negative).
    block_table: torch.Tensor
    # Slot id per token; the save kernel skips tokens whose slot is negative.
    slot_mapping: torch.Tensor
    # Number of slots per state-cache block.
    block_size: int

    # Request index of each token.
    token_to_req_indices: torch.Tensor | None = None  # [num_tokens]
|
||||
|
||||
|
||||
class CompressorMetadataBuilder(AttentionMetadataBuilder):
    """Builds ``CompressorMetadata`` from the common attention metadata.

    A persistent int32 buffer sized to ``max_num_batched_tokens`` backs the
    token -> request index mapping; each ``build`` call fills a prefix of it.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        spec = self.kv_cache_spec
        assert isinstance(spec, SlidingWindowMLASpec | MLAAttentionSpec)
        self.block_size = cast(SlidingWindowMLASpec | MLAAttentionSpec, spec).block_size

        # Persistent buffer reused by every build() call.
        max_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
        self.token_to_req_indices = torch.zeros(
            max_tokens, dtype=torch.int32, device=self.device
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CompressorMetadata:
        """Assemble the compressor metadata for the current batch.

        Args:
            common_prefix_len: Unused by this builder.
            common_attn_metadata: Source of query offsets, block table and
                slot mapping.
            fast_build: Unused by this builder.

        Returns:
            Metadata whose ``token_to_req_indices`` is a view into the
            persistent buffer. Note that the block table is clamped in place.
        """
        starts_cpu = common_attn_metadata.query_start_loc_cpu
        request_ids = torch.arange(common_attn_metadata.num_reqs)
        tokens_per_request = starts_cpu[1:] - starts_cpu[:-1]

        # Expand request ids so each token carries the index of its request,
        # staged in pinned host memory for the asynchronous copy below.
        staged = torch.repeat_interleave(request_ids, tokens_per_request).pin_memory()
        buffer_view = self.token_to_req_indices[: staged.shape[0]]
        buffer_view.copy_(staged, non_blocking=True)

        clamped_block_table = common_attn_metadata.block_table_tensor.clamp_(min=0)
        return CompressorMetadata(
            block_table=clamped_block_table,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_size=self.block_size,
            token_to_req_indices=buffer_view,
        )
|
||||
|
||||
|
||||
class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
    """Layer object that owns the compressor's float32 state cache.

    On construction it registers itself in the compilation config's
    ``static_forward_context`` under ``prefix``. ``kv_cache`` starts out as an
    empty placeholder tensor.
    """

    def __init__(
        self,
        state_dim: int,
        dtype: torch.dtype,
        compress_ratio: int,
        prefix: str,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.dtype = dtype
        self.prefix = prefix
        self.kv_cache = torch.tensor([])

        # Register under a unique name; duplicates are a configuration error.
        forward_context = (
            get_current_vllm_config().compilation_config.static_forward_context
        )
        if prefix in forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_context[prefix] = self

        assert self.dtype == torch.float32
        assert compress_ratio in [4, 128]
        # Ratio 4 keeps two windows' worth of state, ratio 128 keeps one.
        window_multiplier = 2 if compress_ratio == 4 else 1
        self.sliding_window = window_multiplier * compress_ratio

        # Block size is constrained by tensor sharing between compressor states
        # and KV blocks. Since compressor states share the same physical tensor
        # as KV blocks, they must use the same page size.
        # The KV block shape [256//4, head_dim] = [64, 584] determines:
        # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4
        # - C128 compressor block shape [8, 512*2*4] -> block_size = 8
        # TODO(yifan): make block size automatically determined and configurable.
        block_size_by_ratio = {4: 4, 128: 8}
        if compress_ratio not in block_size_by_ratio:
            raise ValueError(f"Invalid compress ratio: {compress_ratio}")
        self.block_size = block_size_by_ratio[compress_ratio]

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the state cache as a single-head sliding-window MLA spec."""
        spec = SlidingWindowMLASpec(  # only has one vector instead of K + V
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
        )
        return spec

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class that describes this cache."""
        return CompressorBackend
|
||||
|
||||
|
||||
class DeepseekCompressor(nn.Module):
    """Stores per-token compressor state and, off Blackwell, launches the
    fused compress -> RMSNorm -> RoPE -> quant -> KV-cache-insert kernel.

    The module owns the additive ``ape`` table, the fused ``wkv``/``wgate``
    projection, an RMSNorm, and a ``CompressorStateCache``. The fused kernel
    and its quantization constants are selected from ``head_dim`` and
    ``use_fp4_cache`` at construction time.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        compress_ratio: int,
        hidden_size: int,
        head_dim: int,
        rotate: bool = False,
        prefix: str = "",
        k_cache_prefix="",
        use_fp4_cache: bool = False,
    ):
        """Build parameters, the state cache, and pick the fused kernel.

        Args:
            vllm_config: Source of the HF model config, scheduler limits,
                quantization config and the static forward context.
            compress_ratio: Compression ratio; a ratio of 4 enables overlap.
            hidden_size: Input width of the fused projection.
            head_dim: Head dimension; only 512 and 128 are supported.
            rotate: Stored on the instance; not otherwise used in this class.
            prefix: Layer-name prefix for the sub-layers created here.
            k_cache_prefix: Name used to look up the target KV cache layer and
                its metadata at forward time.
            use_fp4_cache: Select the MXFP4 kernel (only valid for head_dim 128).

        Raises:
            ValueError: If ``head_dim`` is neither 512 nor 128.
        """
        super().__init__()
        self.compress_ratio = compress_ratio
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.rotate = rotate
        self.prefix = prefix
        self.k_cache_prefix = k_cache_prefix
        self.use_fp4_cache = use_fp4_cache

        config = vllm_config.model_config.hf_config
        self.rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.rms_norm_eps = config.rms_norm_eps
        self.device = current_platform.device_type
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_model_len = vllm_config.model_config.max_model_len

        # With overlap (ratio 4) every per-token width below is doubled.
        self.overlap = compress_ratio == 4
        self.coff = 1 + self.overlap

        state_dtype = torch.float32
        # One row per position within a compression window; the save kernel
        # adds row (position % compress_ratio) to the score.
        self.ape = nn.Parameter(
            torch.empty(
                (compress_ratio, self.coff * self.head_dim),
                dtype=state_dtype,
                device=self.device,
            ),
            requires_grad=False,
        )

        quant_config = vllm_config.quant_config

        # Single projection producing both halves (kv and gate/score).
        self.fused_wkv_wgate = MergedColumnParallelLinear(
            self.hidden_size,
            [self.coff * self.head_dim, self.coff * self.head_dim],
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=f"{prefix}.fused_wkv_wgate",
        )
        self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)

        self.state_cache = CompressorStateCache(
            state_dim=2 * self.coff * self.head_dim,  # kv_state + score_state
            dtype=state_dtype,
            compress_ratio=compress_ratio,
            prefix=f"{prefix}.state_cache",
        )

        # Save reference to static_forward_context for forward-time KV cache lookup.
        # get_current_vllm_config() is only available during __init__, not forward.
        self._static_forward_context = (
            vllm_config.compilation_config.static_forward_context
        )

        # Select the fused kernel and the constants it is launched with.
        if self.head_dim == 512:
            assert not use_fp4_cache, (
                "MXFP4 cache is only supported for indexer (head=128)"
            )
            self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
            self._quant_block = 64
            self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
            self._scale_dim = self.nope_head_dim // 64 + 1  # 7 real + 1 pad
            self._num_warps = 4
        elif self.head_dim == 128:
            if use_fp4_cache:
                self._fused_kernel = (
                    _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
                )
                self._quant_block = MXFP4_BLOCK_SIZE
                self._token_stride = self.head_dim // 2
                self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE
            else:
                self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
                self._quant_block = 128
                self._token_stride = self.head_dim
                self._scale_dim = 4  # single float32 scale
            self._num_warps = 1
        else:
            raise ValueError(
                f"Unsupported head_dim for fused quant+cache: {self.head_dim}"
            )

    def forward(
        self,
        # [num_tokens, 2 * self.coff * self.head_dim]
        kv_score: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        rotary_emb,
    ) -> None:
        """Write this step's state into the caches; returns nothing.

        Returns early without touching any cache when the forward context's
        attention metadata is not a dict. On Blackwell only the partial-state
        save kernel runs; otherwise the fused kernel is launched afterwards.
        """
        # Each of shape [num_tokens, coff * self.head_dim]
        # input bf16, output are fp32
        kv, score = kv_score.split(
            [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
        )

        # Get the metadata and handle dummy profiling run.
        attn_metadata = get_forward_context().attn_metadata
        if not isinstance(attn_metadata, dict):
            return

        state_metadata = cast(
            CompressorMetadata, attn_metadata[self.state_cache.prefix]
        )
        token_to_req_indices = state_metadata.token_to_req_indices
        slot_mapping = state_metadata.slot_mapping
        # One kernel program is launched per slot-mapping entry.
        num_actual = slot_mapping.shape[0]
        block_table = state_metadata.block_table
        block_size = state_metadata.block_size

        # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
        state_cache = self.state_cache.kv_cache
        # kv_state stored in first half, score_state stored in second half
        state_width = state_cache.shape[-1] // 2
        # The launch_pdl keyword is omitted entirely on ROCm.
        pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False}

        # Store the KV and score (with fused APE addition) in the state.
        # NOTE: PDL is disabled — both this kernel and _fused_kernel below
        # depend on preceding kernel outputs (kv/score from the cublas GEMM;
        # state_cache from this kernel) but neither emits/waits on PDL grid
        # dependency primitives, so launch_pdl=True caused a read-after-write
        # race and non-deterministic output.
        _save_partial_states_kernel[(num_actual,)](
            kv,
            kv.stride(0),
            score,
            score.stride(0),
            self.ape,
            self.ape.stride(0),
            positions,
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            slot_mapping,
            block_size,
            HEAD_SIZE=kv.shape[-1],
            TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            **pdl_kwargs,
        )

        # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
        # RoPE requirements (kernel applies forward GPT-J style rotation):
        #   - is_neox_style=False (interleaved pairs, NOT split-half)
        #   - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
        #     second half sin (per-pair, length rope_head_dim // 2 each)
        #   - applied to LAST rope_head_dim elements of head_dim
        #   - position used: (positions // compress_ratio) * compress_ratio

        # On Blackwell (SM100+), skip the fused kernel because:
        # 1. The fused kernel does attention using FlashMLA which doesn't work on SM100
        # 2. Our Blackwell attention path handles everything separately
        # Instead, we just save the state (done above) and let the attention
        # path handle compression + RoPE + cache write + attention.
        if _IS_BLACKWELL:
            # Blackwell: state is already saved, skip fused kernel
            return

        cos_sin_cache = rotary_emb.cos_sin_cache
        # Metadata and cache tensor of the KV layer this compressor writes into.
        k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
        kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache

        self._fused_kernel[(num_actual,)](
            # state cache
            state_cache,
            state_cache.stride(0),
            state_cache.stride(1),
            # metadata
            token_to_req_indices,
            positions,
            slot_mapping,
            block_table,
            block_table.stride(0),
            block_size,
            # RMSNorm
            self.norm.weight,
            self.rms_norm_eps,
            # RoPE
            cos_sin_cache,
            cos_sin_cache.stride(0),
            # KV cache
            kv_cache,
            k_cache_metadata.slot_mapping,
            kv_cache.shape[1],  # paged KV cache block size (tokens per block)
            # constexprs
            HEAD_SIZE=self.head_dim,
            TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim),
            STATE_WIDTH=state_width,
            COMPRESS_RATIO=self.compress_ratio,
            OVERLAP=self.overlap,
            ROPE_HEAD_DIM=self.rope_head_dim,
            FP8_MAX=448.0,
            QUANT_BLOCK=self._quant_block,
            TOKEN_STRIDE=self._token_stride,
            SCALE_DIM=self._scale_dim,
            KV_BLOCK_STRIDE=kv_cache.stride(0),
            num_warps=self._num_warps,
            **pdl_kwargs,
        )
|
||||
|
||||
|
||||
@triton.jit
def _save_partial_states_kernel(
    kv_ptr,
    kv_stride,
    score_ptr,
    score_stride,
    ape_ptr,
    ape_stride,
    positions_ptr,
    state_cache_ptr,
    state_cache_stride0,
    state_cache_stride1,
    slot_mapping_ptr,
    block_size,
    HEAD_SIZE: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
    STATE_WIDTH: tl.constexpr,
    COMPRESS_RATIO: tl.constexpr,
):
    # Copies one token's kv row and (score + ape) row into its state-cache
    # slot. One program handles one token; program id 0 is the token index.
    token_idx = tl.program_id(0)
    slot_id = tl.load(slot_mapping_ptr + token_idx)

    # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
    # by vLLM).  During CUDA graph replay the batch may contain padding
    # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
    # illegal memory access.
    if slot_id < 0:
        return

    # Decompose the flat slot id into (block, offset within block) and form
    # the address of that slot's row in the state cache.
    block_idx = slot_id // block_size
    pos_in_block = slot_id % block_size
    base_ptr = (
        state_cache_ptr
        + block_idx * state_cache_stride0
        + pos_in_block * state_cache_stride1
    )

    # TRITON_BLOCK_SIZE lanes cover the row; lanes at or beyond HEAD_SIZE are
    # masked out of every load and store.
    block = tl.arange(0, TRITON_BLOCK_SIZE)
    mask = block < HEAD_SIZE

    # kv goes into the first part of the slot row, unchanged.
    kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
    tl.store(base_ptr + block, kv, mask=mask)

    # Fused: score += ape[position % compress_ratio]
    position = tl.load(positions_ptr + token_idx)
    ape_row = position % COMPRESS_RATIO
    ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
    score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
    # The score part starts STATE_WIDTH elements into the slot row.
    tl.store(
        base_ptr + STATE_WIDTH + block,
        score + ape,
        mask=mask,
    )
|
||||
@@ -1,195 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Patched MHC layer — pure PyTorch, no TileLang.
|
||||
# Replaces the TileLang-based mhc.py from the vLLM nightly image.
|
||||
# The original imports tilelang at the top level and JIT-compiles kernels
|
||||
# which don't work correctly on Blackwell (SM100).
|
||||
|
||||
import torch
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
# ── Pure PyTorch MHC implementations ──────────────────────────────────
|
||||
|
||||
def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the pre-layer MHC mixes and the mixed layer input.

    Args:
        residual: bfloat16 tensor ``[..., hc_mult, hidden_size]``. It may be
            non-contiguous; it is reshaped (copied if necessary) internally.
        fn: float32 projection with ``2 * hc_mult + hc_mult**2`` rows and
            ``hc_mult * hidden_size`` columns.
        hc_scale: float32 tensor; elements 0, 1 and 2 scale the pre, post and
            combination logits respectively.
        hc_base: float32 bias, laid out like the rows of ``fn``.
        rms_eps: Epsilon inside the RMS normalisation of the mixes.
        hc_pre_eps: Added to the pre-mix after the sigmoid.
        hc_sinkhorn_eps: Epsilon used throughout the Sinkhorn normalisation.
        hc_post_mult_value: Multiplier applied to the post-mix sigmoid.
        sinkhorn_repeat: Total number of Sinkhorn iterations.
        n_splits: Accepted for interface compatibility; unused here.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` with shapes
        ``[..., hc_mult, 1]`` (float32), ``[..., hc_mult, hc_mult]`` (float32)
        and ``[..., hidden_size]`` (bfloat16).
    """
    assert residual.dtype == torch.bfloat16
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    hc_mult = residual.shape[-2]
    hidden_size = residual.shape[-1]
    hc_mult2 = hc_mult * hc_mult
    hc_mult3 = hc_mult * 2 + hc_mult2
    hc_hidden_size = hc_mult * hidden_size
    outer_shape = residual.shape[:-2]

    # The slicing below only works when fn has exactly hc_mult3 rows; fail
    # early with a clear message instead of a shape error deep in the math.
    assert fn.shape[0] == hc_mult3, (
        f"fn must have {hc_mult3} rows for hc_mult={hc_mult}, got {fn.shape[0]}"
    )

    # reshape (not view) so strided / sliced residuals are accepted.
    residual_flat = residual.reshape(-1, hc_mult, hidden_size)
    num_tokens = residual_flat.shape[0]

    x = residual_flat.reshape(num_tokens, hc_hidden_size).to(torch.float32)
    mixes = torch.matmul(x, fn.t())
    # RMS-normalise the mixes by the per-token RMS of the flattened residual.
    sqrsum = x.square().sum(dim=-1, keepdim=True)
    mixes = mixes * torch.rsqrt(sqrsum / hc_hidden_size + rms_eps)

    pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult]
    pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps

    post_logits = (
        mixes[:, hc_mult : 2 * hc_mult] * hc_scale[1]
        + hc_base[hc_mult : 2 * hc_mult]
    )
    post_mix = torch.sigmoid(post_logits) * hc_post_mult_value

    comb_logits = (
        mixes[:, 2 * hc_mult :].reshape(num_tokens, hc_mult, hc_mult) * hc_scale[2]
        + hc_base[2 * hc_mult :].reshape(1, hc_mult, hc_mult)
    )
    # Sinkhorn: softmax over the last dim is the first row normalisation, then
    # a normalisation over dim -2, followed by (repeat - 1) alternating passes.
    comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps
    comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
    for _ in range(sinkhorn_repeat - 1):
        comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
        comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)

    # Weighted sum over the hc_mult residual streams.
    layer_input = torch.sum(
        pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1
    ).to(torch.bfloat16)

    return (
        post_mix.reshape(*outer_shape, hc_mult, 1),
        comb_mix.reshape(*outer_shape, hc_mult, hc_mult),
        layer_input.reshape(*outer_shape, hidden_size),
    )
|
||||
|
||||
|
||||
def _mhc_pre_fake(
|
||||
residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps,
|
||||
hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits=1,
|
||||
):
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
outer_shape = residual.shape[:-2]
|
||||
return (
|
||||
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
|
||||
)
|
||||
|
||||
|
||||
def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Fold a layer output back into the multi-stream residual.

    The residual streams are recombined with ``comb_res_mix`` (contracting its
    first matrix index against the stream index), the layer output ``x`` is
    broadcast across streams and weighted by ``post_layer_mix``, and the sum
    is returned in ``residual``'s dtype. All arithmetic is done in float32.
    """
    comb_f32 = comb_res_mix.to(torch.float32)
    residual_f32 = residual.to(torch.float32)
    recombined = torch.einsum("...ij,...ih->...jh", comb_f32, residual_f32)

    # x gains a singleton stream axis so it broadcasts against the mix.
    injected = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32)

    total = recombined + injected
    return total.to(residual.dtype)
|
||||
|
||||
|
||||
def _mhc_post_fake(x, residual, post_layer_mix, comb_res_mix):
    # Shape/dtype-only stand-in for mhc_post: returns an uninitialised tensor
    # shaped like `residual`; the other arguments are not read.
    return torch.empty_like(residual)
|
||||
|
||||
|
||||
def mhc_fused_post_pre(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
    tile_n: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run ``mhc_post`` and then ``mhc_pre`` on its result.

    Returns ``(new_residual, post_mix, res_mix, layer_input)``: the updated
    residual followed by the three outputs of ``mhc_pre``. ``tile_n`` is
    accepted for interface compatibility and is not used.
    """
    updated_residual = mhc_post(x, residual, post_layer_mix, comb_res_mix)

    pre_outputs = mhc_pre(
        updated_residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits,
    )
    next_post_mix, next_res_mix, next_layer_input = pre_outputs
    return updated_residual, next_post_mix, next_res_mix, next_layer_input
|
||||
|
||||
|
||||
def _mhc_fused_post_pre_fake(
|
||||
x, residual, post_layer_mix, comb_res_mix, fn, hc_scale, hc_base,
|
||||
rms_eps, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value,
|
||||
sinkhorn_repeat, n_splits=1, tile_n=1,
|
||||
):
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
outer_shape = residual.shape[:-2]
|
||||
return (
|
||||
torch.empty_like(residual),
|
||||
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
|
||||
)
|
||||
|
||||
|
||||
def _hc_head_fused_kernel(
|
||||
hs_flat: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
hidden_size: int,
|
||||
rms_eps: float,
|
||||
hc_eps: float,
|
||||
hc_mult: int,
|
||||
) -> None:
|
||||
if hs_flat.shape[0] == 0:
|
||||
return
|
||||
x_flat = hs_flat.reshape(hs_flat.shape[0], hc_mult * hidden_size).to(torch.float32)
|
||||
mixes = torch.matmul(x_flat, fn.t())
|
||||
sqrsum = x_flat.square().sum(dim=-1, keepdim=True)
|
||||
rsqrt = torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps)
|
||||
pre_mix = torch.sigmoid(mixes * rsqrt * hc_scale[0] + hc_base) + hc_eps
|
||||
result = torch.sum(pre_mix.unsqueeze(-1) * hs_flat.to(torch.float32), dim=1).to(out.dtype)
|
||||
out.copy_(result)
|
||||
|
||||
|
||||
# ── Register as torch custom ops ──────────────────────────────────────
|
||||
|
||||
# Register each pure-PyTorch function under its op name. The three MHC ops
# declare no mutated arguments and supply a fake implementation.
direct_register_custom_op(
    op_name="mhc_pre",
    op_func=mhc_pre,
    mutates_args=[],
    fake_impl=_mhc_pre_fake,
)

direct_register_custom_op(
    op_name="mhc_post",
    op_func=mhc_post,
    mutates_args=[],
    fake_impl=_mhc_post_fake,
)

direct_register_custom_op(
    op_name="mhc_fused_post_pre",
    op_func=mhc_fused_post_pre,
    mutates_args=[],
    fake_impl=_mhc_fused_post_pre_fake,
)

# This op writes its result into `out` and returns None, so `out` is declared
# as mutated; no fake implementation is passed.
direct_register_custom_op(
    op_name="hc_head_fused_kernel",
    op_func=_hc_head_fused_kernel,
    mutates_args=["out"],
)
|
||||
@@ -1,48 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch DeepseekCompressor cache for Blackwell: remove FlashMLA alignment."""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Rewrite the compressor cache spec in ``path`` to drop alignment on Blackwell.

    The file is left untouched when it already contains the
    ``CLAWMINE_PATCH_COMPRESSOR`` marker. If the expected snippet is missing,
    an error is written to stderr and the process exits with status 1.

    Args:
        path: File containing the ``SlidingWindowMLASpec(...)`` return to patch.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    if "CLAWMINE_PATCH_COMPRESSOR" in content:
        print("Already patched, skipping")
        return

    # NOTE(review): these literals must match the target file's whitespace
    # exactly — verify against the installed vLLM source before relying on them.
    old = """        return SlidingWindowMLASpec(  # only has one vector instead of K + V
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
        )"""

    new = """        # CLAWMINE_PATCH_COMPRESSOR: No FlashMLA alignment on Blackwell
        from vllm.platforms import current_platform
        _is_blackwell = (
            current_platform.get_device_capability() is not None
            and current_platform.get_device_capability().major >= 10
        )
        return SlidingWindowMLASpec(  # only has one vector instead of K + V
            block_size=self.block_size,
            num_kv_heads=1,
            head_size=self.state_dim,
            dtype=self.dtype,
            sliding_window=self.sliding_window,
            alignment=None if _is_blackwell else 576,
        )"""

    if old not in content:
        # Diagnostics belong on stderr so they are not mixed into normal output.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print("Patched DeepseekCompressor for Blackwell")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <path>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch _allocate_kv_cache_tensors to print the layer name mismatch."""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Insert layer-name mismatch diagnostics before the KV-cache layer assert.

    The file is left untouched when it already contains the
    ``CLAWMINE_DEBUG_LAYERS`` marker. If the expected assert is missing, an
    error is written to stderr and the process exits with status 1.

    Args:
        path: File containing the ``assert layer_names == ...`` statement.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    if "CLAWMINE_DEBUG_LAYERS" in content:
        print("Already patched, skipping")
        return

    # NOTE(review): these literals must match the target file's whitespace
    # exactly — verify against the installed vLLM source before relying on them.
    old = """        assert layer_names == set(kv_cache_raw_tensors.keys()), (
            "Some layers are not correctly initialized"
        )"""

    new = """        # CLAWMINE_DEBUG_LAYERS: print mismatch instead of asserting
        missing = layer_names - set(kv_cache_raw_tensors.keys())
        extra = set(kv_cache_raw_tensors.keys()) - layer_names
        if missing or extra:
            print(f"CLAWMINE DEBUG: missing layers ({len(missing)}): {sorted(missing)[:20]}")
            print(f"CLAWMINE DEBUG: extra layers ({len(extra)}): {sorted(extra)[:20]}")
            print(f"CLAWMINE DEBUG: expected ({len(layer_names)}), got ({len(kv_cache_raw_tensors.keys())})")
        assert layer_names == set(kv_cache_raw_tensors.keys()), (
            "Some layers are not correctly initialized"
        )"""

    if old not in content:
        # Diagnostics belong on stderr so they are not mixed into normal output.
        print("ERROR: Could not find the code to patch", file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print("Patched gpu_model_runner.py for debug layer names")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <path>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch DeepseekV4IndexerCache on Blackwell: remove FlashMLA alignment.
|
||||
|
||||
Same as patch_swa_cache but for the indexer cache class.
|
||||
"""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Rewrite the indexer cache spec in ``path`` to drop alignment on Blackwell.

    The file is left untouched when it already contains the
    ``CLAWMINE_PATCH_INDEXER_CACHE`` marker. If the expected method is
    missing, an error is written to stderr and the process exits with status 1.

    Args:
        path: File containing the indexer cache's ``get_kv_cache_spec``.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    if "CLAWMINE_PATCH_INDEXER_CACHE" in content:
        print("Already patched, skipping")
        return

    # Patch the indexer cache's get_kv_cache_spec to remove FlashMLA alignment on Blackwell
    # NOTE(review): these literals must match the target file's whitespace
    # exactly — verify against the installed vLLM source before relying on them.
    old = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # head_dim already carries the fp8 scale padding
        # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
            # the indexer's compressor state cache. V3.2 keeps the legacy layout.
            alignment=576,
        )"""

    new = """    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # CLAWMINE_PATCH_INDEXER_CACHE: No FlashMLA alignment on Blackwell
        from vllm.platforms import current_platform
        _is_blackwell = (
            current_platform.get_device_capability() is not None
            and current_platform.get_device_capability().major >= 10
        )
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            alignment=None if _is_blackwell else 576,
        )"""

    if old not in content:
        # Diagnostics belong on stderr so they are not mixed into normal output.
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print("Patched DeepseekV4IndexerCache for Blackwell")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <path>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|
||||
@@ -1,135 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch vLLM kv_cache_utils.py to handle DeepseekV4 SWA page sizes.
|
||||
|
||||
DeepseekV4 has three cache types:
|
||||
- C128A (HCA): compress_ratio=128, very small page size
|
||||
- C4A (CSA): compress_ratio=4, medium page size
|
||||
- SWA: compress_ratio=1, large page size
|
||||
|
||||
The upstream code assumes SWA page sizes <= MLA page sizes and pads
|
||||
SWA pages to match MLA. This breaks when SWA pages are LARGER than
|
||||
MLA pages (which is always the case for DeepseekV4).
|
||||
|
||||
Our fix: when SWA pages exceed MLA pages, put them in their own
|
||||
separate cache group without padding.
|
||||
"""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Make the SWA/MLA page-size grouping in ``path`` tolerate larger SWA pages.

    The original padding logic is kept for the case where every SWA page size
    fits under the largest MLA page size; otherwise each SWA layer is placed
    in its own cache group. The file is left untouched when it already
    contains the ``CLAWMINE_PATCH_KV_CACHE`` marker. If the expected snippet
    is missing, an error is written to stderr and the process exits with
    status 1.

    Args:
        path: File containing the page-size unification code to replace.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    if "CLAWMINE_PATCH_KV_CACHE" in content:
        print("Already patched, skipping")
        return

    # The old code: asserts SWA pages <= MLA pages, then pads SWA to MLA
    # NOTE(review): these literals must match the target file's whitespace
    # exactly, including the base indentation, which is assumed to be four
    # spaces here — verify against the installed vLLM source.
    old = """    assert max(sm_page_sizes) <= max(all_page_sizes)

    # Unify page size by padding layers' page_size to the nearest larger page_size.
    # Compute candidate (nearest larger page_size) for each unique page size.
    size_to_candidate: dict[int, int] = {}
    for ps in sm_page_sizes:
        size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
    # Pad and collect layer names per page size.
    for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
        current_size = layer_spec.page_size_bytes
        candidate = size_to_candidate[current_size]
        if current_size < candidate:
            object.__setattr__(layer_spec, "page_size_padded", candidate)
        layers_per_size[candidate].append(layer_name)
    # NOTE(yifan): for now, inside a UniformKV group, each page_size should
    # have the same number of layers. This also means we don't need to pad layers
    # inside a partial-full layer tuple.
    assert len(set(len(layers) for layers in layers_per_size.values())) == 1
    num_layers_per_size = len(next(iter(layers_per_size.values())))

    # Split layers inside each UniformKV group for aligned #(layers).
    # See `_get_kv_cache_groups_uniform_page_size` for more details.
    num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
    layer_tuples = list(zip(*layers_per_size.values()))
    for i in range(num_tuple_groups):
        group_layer_tuples = layer_tuples[i::num_tuple_groups]
        # Flatten tuples and build dict for from_specs
        group_layer_names = [
            name for layer_tuple in group_layer_tuples for name in layer_tuple
        ]
        group_layer_specs = {
            name: sm_spec.kv_cache_specs[name] for name in group_layer_names
        }
        sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
        assert sub_sm_spec is not None
        swa_mla_groups.append(
            KVCacheGroupSpec(
                layer_names=group_layer_names,
                kv_cache_spec=sub_sm_spec,
            )
        )"""

    # The new code: handle both cases
    new = """    # CLAWMINE_PATCH_KV_CACHE: Handle DeepseekV4 where SWA page sizes
    # can be larger than MLA page sizes. Two cases:
    # 1. All SWA pages <= some MLA page: original padding logic
    # 2. Some SWA pages > all MLA pages: separate cache group, no padding
    max_mla_page = max(all_page_sizes)
    can_pad = max(sm_page_sizes) <= max_mla_page

    if can_pad:
        # Original logic: pad SWA pages to nearest MLA page
        size_to_candidate: dict[int, int] = {}
        for ps in sm_page_sizes:
            size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
        for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
            current_size = layer_spec.page_size_bytes
            candidate = size_to_candidate[current_size]
            if current_size < candidate:
                object.__setattr__(layer_spec, "page_size_padded", candidate)
            layers_per_size[candidate].append(layer_name)
        assert len(set(len(layers) for layers in layers_per_size.values())) == 1
        num_layers_per_size = len(next(iter(layers_per_size.values())))
        num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
        layer_tuples = list(zip(*layers_per_size.values()))
        for i in range(num_tuple_groups):
            group_layer_tuples = layer_tuples[i::num_tuple_groups]
            group_layer_names = [
                name for layer_tuple in group_layer_tuples for name in layer_tuple
            ]
            group_layer_specs = {
                name: sm_spec.kv_cache_specs[name] for name in group_layer_names
            }
            sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
            assert sub_sm_spec is not None
            swa_mla_groups.append(
                KVCacheGroupSpec(
                    layer_names=group_layer_names,
                    kv_cache_spec=sub_sm_spec,
                )
            )
    else:
        # SWA pages are larger than MLA pages.
        # Put each SWA layer in its own cache group (no padding needed).
        # This is the DeepseekV4 Blackwell case where compress_ratio=1
        # layers have much larger pages than compressed layers.
        for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
            group_layer_specs = {layer_name: layer_spec}
            sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
            if sub_sm_spec is not None:
                swa_mla_groups.append(
                    KVCacheGroupSpec(
                        layer_names=[layer_name],
                        kv_cache_spec=sub_sm_spec,
                    )
                )"""

    if old not in content:
        # Diagnostics belong on stderr so they are not mixed into normal output.
        print("ERROR: Could not find the code to patch", file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print("Patched kv_cache_utils.py for DeepseekV4 SWA page sizes")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <path>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|
||||
@@ -1,71 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Patch DeepseekV4SWACache on Blackwell: remove FlashMLA alignment and model_version.
|
||||
|
||||
On Blackwell (SM100+), FlashMLA doesn't work. We use our own CSA/SDPA attention.
|
||||
The SWA cache should use standard fp8 format (not fp8_ds_mla) and no FlashMLA alignment.
|
||||
"""
|
||||
import sys
|
||||
|
||||
def patch(path):
    """Rewrite ``get_kv_cache_spec`` in the file at ``path`` to add a capability branch.

    The exact text of the existing method (a single ``return
    SlidingWindowMLASpec(...)`` with ``alignment=576`` and
    ``model_version="deepseek_v4"``) is replaced by a version that first checks
    ``current_platform.get_device_capability().major >= 10`` and, when true,
    returns the same spec with ``alignment=None`` and ``model_version=None``.
    The original return is kept as the fallback.

    Args:
        path: Path of the source file to modify in place.

    Returns:
        None. Returns early, without writing, when the file already contains
        the ``CLAWMINE_PATCH_SWA_CACHE`` marker.

    Raises:
        SystemExit: With status 1 when the expected method text is not present
            in the file; the file is left untouched in that case.
    """
    with open(path, 'r') as handle:
        source_text = handle.read()

    # The marker is part of the injected comment, so its presence means a
    # previous run already rewrote the method.
    if "CLAWMINE_PATCH_SWA_CACHE" in source_text:
        print("Already patched, skipping")
        return

    # NOTE(review): the whitespace inside the literals below (4-space method
    # indent, two spaces before inline comments) was reconstructed from a copy
    # whose whitespace had been collapsed -- confirm it matches the target file
    # byte-for-byte before relying on this script.
    signature = "    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:\n"

    # Body of the method as it exists before patching; reused verbatim as the
    # fallback return of the patched method.
    flashmla_return = "\n".join([
        "        return SlidingWindowMLASpec(",
        "            block_size=self.block_size,",
        "            num_kv_heads=1,",
        "            head_size=self.head_dim,",
        "            dtype=self.dtype,",
        "            sliding_window=self.window_size,",
        "            cache_dtype_str=self.cache_config.cache_dtype,",
        "            alignment=576,  # NOTE: FlashMLA requires 576B alignment",
        "            model_version=\"deepseek_v4\",",
        "        )",
    ])

    # Lines injected between the signature and the original return.
    blackwell_branch = "\n".join([
        "        # CLAWMINE_PATCH_SWA_CACHE: On Blackwell, no FlashMLA = no 576B alignment",
        "        # Use standard fp8 format (not fp8_ds_mla), no model_version override",
        "        from vllm.platforms import current_platform",
        "        _is_blackwell = (",
        "            current_platform.get_device_capability() is not None",
        "            and current_platform.get_device_capability().major >= 10",
        "        )",
        "        if _is_blackwell:",
        "            return SlidingWindowMLASpec(",
        "                block_size=self.block_size,",
        "                num_kv_heads=1,",
        "                head_size=self.head_dim,",
        "                dtype=self.dtype,",
        "                sliding_window=self.window_size,",
        "                cache_dtype_str=self.cache_config.cache_dtype,",
        "                alignment=None,  # No FlashMLA alignment on Blackwell",
        "                model_version=None,  # Don't use 584B deepseek_v4 format",
        "            )",
        "",
    ])

    original_method = signature + flashmla_return
    patched_method = signature + blackwell_branch + flashmla_return

    if original_method not in source_text:
        print("ERROR: Could not find the code to patch in " + path)
        sys.exit(1)

    patched_text = source_text.replace(original_method, patched_method)

    with open(path, 'w') as handle:
        handle.write(patched_text)
    print("Patched DeepseekV4SWACache for Blackwell")
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: the first command-line argument is the file to patch.
    # Without this check a missing argument surfaces as a bare IndexError
    # traceback instead of a usable message.
    if len(sys.argv) < 2:
        print("usage: " + sys.argv[0] + " <path-to-file-to-patch>")
        sys.exit(2)
    patch(sys.argv[1])
|
||||
@@ -1,46 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Patch vLLM's linear kernel __init__.py to register the CuTeDSL NVFP4 kernel.
|
||||
# This inserts our kernel at the TOP of the _POSSIBLE_NVFP4_KERNELS list,
|
||||
# so it gets selected first on Blackwell GPUs.
|
||||
|
||||
import sys
|
||||
|
||||
def patch_init(path):
    """Register ``CuTeDSLNvFp4LinearKernel`` in the Python source file at ``path``.

    Three textual edits are applied:

    1. An import of ``CuTeDSLNvFp4LinearKernel`` is inserted at the first blank
       line that follows the marlin import marker.
    2. ``CuTeDSLNvFp4LinearKernel`` is placed directly ahead of
       ``FlashInferCutlassNvFp4LinearKernel`` wherever that kernel is the first
       entry of a ``PlatformEnum.CUDA: [`` list.
    3. A ``"cutedsl"`` entry is added after the ``"emulation"`` entry that
       closes a mapping (``}`` at column 0 on the next line).

    Edits 2 and 3 copy whatever indentation the matched lines already use, so
    they do not depend on a particular indent width. Edits 1 and 2 are
    required: if either anchor is missing the function exits *before* writing,
    so the file is never left half-patched (a half-patched file would contain
    the kernel name and make every later run skip). Edit 3 is optional and only
    produces a warning when its anchor is missing.

    Args:
        path: Path of the file to modify in place.

    Returns:
        None. Returns early, without writing, when the file already mentions
        ``CuTeDSLNvFp4LinearKernel``.

    Raises:
        SystemExit: With status 1 when the marlin import marker, the blank line
            ending that import block, or the CUDA kernel list cannot be found.
    """
    import re  # local import: nothing else in this script needs it

    with open(path, 'r') as f:
        content = f.read()

    # Idempotency guard: any mention of the kernel means a previous run
    # already patched this file.
    if "CuTeDSLNvFp4LinearKernel" in content:
        print("CuTeDSL kernel already registered, skipping")
        return

    import_line = (
        "from vllm.model_executor.kernels.linear.nvfp4.cutedsl import (\n"
        "    CuTeDSLNvFp4LinearKernel,\n"
        ")\n"
    )

    # --- Edit 1: add the import after the marlin import block ---------------
    marker = "from vllm.model_executor.kernels.linear.nvfp4.marlin import ("
    idx = content.find(marker)
    if idx == -1:
        print("ERROR: Could not find marlin import marker")
        sys.exit(1)
    # The import block is taken to end at the first blank line after the marker.
    end = content.find("\n\n", idx)
    if end == -1:
        # Slicing with -1 would splice the import in front of the file's last
        # character and corrupt it, so stop instead.
        print("ERROR: Could not find end of marlin import block")
        sys.exit(1)
    content = content[:end] + "\n" + import_line + content[end:]

    # --- Edit 2: put the kernel at the top of the CUDA kernel list ----------
    cuda_list = re.compile(
        r"^(?P<outer>[ \t]*)PlatformEnum\.CUDA: \[\n"
        r"(?P<inner>[ \t]*)FlashInferCutlassNvFp4LinearKernel,",
        re.MULTILINE,
    )
    content, list_hits = cuda_list.subn(
        r"\g<outer>PlatformEnum.CUDA: [\n"
        r"\g<inner>CuTeDSLNvFp4LinearKernel,\n"
        r"\g<inner>FlashInferCutlassNvFp4LinearKernel,",
        content,
    )
    if list_hits == 0:
        # Without this entry the import alone has no effect, and writing it
        # would trip the idempotency guard on every later run.
        print("ERROR: Could not find the CUDA NVFP4 kernel list")
        sys.exit(1)

    # --- Edit 3: expose the kernel under the "cutedsl" backend key ----------
    backend_tail = re.compile(
        r'^(?P<indent>[ \t]*)"emulation": EmulationNvFp4LinearKernel,\n\}',
        re.MULTILINE,
    )
    content, backend_hits = backend_tail.subn(
        r'\g<indent>"emulation": EmulationNvFp4LinearKernel,\n'
        r'\g<indent>"cutedsl": CuTeDSLNvFp4LinearKernel,\n}',
        content,
    )
    if backend_hits == 0:
        print('WARNING: Could not find the "emulation" backend entry; '
              '"cutedsl" backend name not registered')

    with open(path, 'w') as f:
        f.write(content)
    print("Patched CuTeDSL NVFP4 kernel into", path)
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: the first command-line argument is the linear-kernel
    # __init__.py to patch. Without this check a missing argument surfaces as
    # a bare IndexError traceback instead of a usable message.
    if len(sys.argv) < 2:
        print("usage: " + sys.argv[0] + " <path/to/__init__.py>")
        sys.exit(2)
    patch_init(sys.argv[1])
|
||||
Reference in New Issue
Block a user