[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)

Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
Maral
2026-04-09 08:50:39 +08:00
committed by GitHub
parent 8332078cfd
commit 2e9034c998
35 changed files with 1710 additions and 904 deletions

View File

@@ -9,11 +9,12 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
weight_group_shape = GroupShape(block_n, block_k)
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=weight_group_shape,
act_quant_group_shape=act_quant_group_shape,
cutlass_block_fp8_supported=use_cutlass,
use_aiter_and_is_supported=False,
linear_op = init_fp8_linear_kernel(
weight_quant_key=create_fp8_quant_key(
static=True, group_shape=weight_group_shape
),
activation_quant_key=create_fp8_quant_key(
static=False, group_shape=act_quant_group_shape
),
out_dtype=torch.get_default_dtype(),
module_name="build_w8a8_block_fp8_runner",
)
def run():

View File

@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=dtype,
)
for i in range(3)
]
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
model = test_model_cls(hidden_size, token_num, dtype=dtype)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

View File

@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=self.vllm_config.model_config.dtype,
)
for i in range(3)
]

View File

@@ -23,6 +23,7 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)
def forward(self, x):
@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
weight_shape=(hidden_size, intermediate_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)
def forward(self, hidden_states, residual):

View File

@@ -9,7 +9,7 @@ import vllm.config
import vllm.ir.ops
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
from tests.utils import TestFP8Layer
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
@@ -28,19 +28,23 @@ from vllm.config import (
VllmConfig,
)
from vllm.model_executor.kernels.linear import (
AiterFp8BlockScaledMMKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
DeepGemmFp8BlockScaledMMKernel,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
FlashInferFP8ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
TritonFp8BlockScaledMMKernel,
_KernelT,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
# Blockwise group shapes
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]
# ROCm kernels
@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]
KERNEL_GROUPSHAPE_COMBINATIONS = (
@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Blockwise (no kernel abstraction)
(None, GroupShape(1, 128), True),
# Blockwise
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
]
@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
self,
hidden_size: int,
eps: float,
force_kernel: FP8ScaledMMLinearKernel | None,
force_kernel: type[_KernelT] | None,
group_shape: GroupShape,
dtype: torch.dtype,
use_aiter_fusion: bool = False,
use_aiter_quant: bool = False,
*args,
@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
is_blockwise = group_shape.is_per_group()
if is_blockwise:
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
block_size = group_shape.col
self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=group_shape
)
self.fp8_linear_layers = [
TestBlockFP8Layer(
weight_shape=(hidden_size, hidden_size),
group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=use_aiter_quant,
transpose_weights=use_aiter_fusion,
)
for _ in range(3)
]
self.enable_quant_fp8_custom_op = (
False
if use_aiter_quant
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(block_size, block_size)
)
else:
is_static = group_shape == GroupShape.PER_TENSOR
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
self.activation_quant_key = create_fp8_quant_key(
is_static, group_shape=group_shape
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=group_shape
)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
)
for _ in range(3)
]
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
transpose_weights=use_aiter_fusion,
input_dtype=dtype,
)
for _ in range(3)
]
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=False,
use_aiter_quant=False,
)
@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
)

View File

@@ -66,6 +66,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.attn = Attention(
num_heads=self.num_qo_heads,
@@ -155,6 +156,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)
w = kwargs.get("w")

View File

@@ -74,6 +74,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
# Create kv_b_proj (ColumnParallelLinear) on device.
# Reuse weights from prior model instance when available, because
@@ -190,6 +191,7 @@ class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)
w = kwargs.get("w")

View File

@@ -36,9 +36,9 @@ from vllm.model_executor.kernels.linear import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
@@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
self,
hidden_size: int,
force_kernel: FP8ScaledMMLinearKernel,
dtype: torch.dtype,
**kwargs,
):
super().__init__()
self.silu_and_mul = SiluAndMul()
@@ -68,6 +72,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
force_kernel=force_kernel,
input_dtype=dtype,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
@@ -137,14 +142,20 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
act_quant_key = kFp8Dynamic128Sym
def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(hidden_size, hidden_size)
)
self.w8a8_block_fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
weight_quant_key=self.weight_quant_key,
activation_quant_key=self.act_quant_key,
input_dtype=dtype,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
@@ -157,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
@@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant(
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
model = model_class(
hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype
)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)

View File

@@ -246,8 +246,9 @@ def default_vllm_config():
"""
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
yield
config = VllmConfig()
with set_current_vllm_config(config):
yield config
@pytest.fixture()

View File

@@ -12,8 +12,8 @@ from tests.kernels.quant_utils import (
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import cutlass_scaled_mm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm,
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
)
@@ -202,7 +202,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes are supported by deepgemm
if not should_use_deepgemm_for_fp8_linear(
output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
output_dtype=out_dtype, weight_shape=B_fp32.shape, supports_deep_gemm=True
):
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")

View File

@@ -16,6 +16,9 @@ from compressed_tensors.quantization import (
)
from tests.models.utils import check_logprobs_close
from vllm.model_executor.kernels.linear import (
Fp8BlockScaledMMLinearKernel,
)
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
@@ -29,7 +32,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsWNA16,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
@@ -473,16 +475,14 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(
qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
)
assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel)
assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2
input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method in (
input_quant_op.forward_cuda,

View File

@@ -13,6 +13,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.config.model import ModelConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
@@ -406,6 +407,8 @@ def test_fp8_reloading(
"If this is your use case, consider using a restore function like #26327"
)
# Set model config as model_config.dtype is required in Fp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,

View File

@@ -12,6 +12,7 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.config.model import ModelConfig
@pytest.fixture(scope="function", autouse=True)
@@ -46,7 +47,7 @@ def _snapshot_download_or_skip(model_id: str) -> str:
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
# TODO: provide a small publicly available test checkpoint
model_path = (
@@ -61,6 +62,8 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
"This test requires a local ModelOpt FP8 checkpoint."
)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
@@ -120,11 +123,13 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_pc_pt_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
model_path = _snapshot_download_or_skip(model_id)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
@@ -181,11 +186,13 @@ def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_pb_wo_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
model_path = _snapshot_download_or_skip(model_id)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):

View File

@@ -43,12 +43,10 @@ from vllm.distributed import (
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.kernels.linear import (
FP8ScaledMMLinearKernel,
_KernelT,
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
)
from vllm.model_executor.model_loader import get_model_loader
@@ -1811,31 +1809,52 @@ class TestFP8Layer(torch.nn.Module):
weight_shape: tuple[int, int],
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
input_dtype: torch.dtype,
out_dtype: torch.dtype | None = None,
transpose_weights: bool = False,
device: torch.device | None = None,
force_kernel: FP8ScaledMMLinearKernel | None = None,
force_kernel: type[_KernelT] | None = None,
):
super().__init__()
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
is_static_activation_scale = activation_quant_key.scale.static
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
self.weight_scale = torch.rand(
weight_scale_shape, dtype=torch.float32, device=device
)
self.input_scale = (
torch.rand(1, dtype=torch.float32, device=device)
if is_static_activation_scale
else None
)
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
self.input_scale_ub = None
act_scale_desc = activation_quant_key.scale
weight_scale_desc = weight_quant_key.scale
is_block_wise = act_scale_desc.group_shape.is_per_group()
if is_block_wise:
block_size = weight_scale_desc.group_shape.col
weight_scale_shape = weight_shape[0] // block_size
self.weight_scale_inv = torch.rand(
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
)
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
self.weight_scale = None
if transpose_weights:
self.weight = self.weight.t()
else:
per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor()
is_static_activation_scale = act_scale_desc.static
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
self.weight_scale_inv = None
self.weight_scale = torch.rand(
weight_scale_shape, dtype=torch.float32, device=device
)
self.input_scale = (
torch.rand(1, dtype=torch.float32, device=device)
if is_static_activation_scale
else None
)
self.weight = (
torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
)
self.input_scale_ub = None
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
self.kernel = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype,
out_dtype=out_dtype,
force_kernel=force_kernel,
)
@@ -1847,61 +1866,3 @@ class TestFP8Layer(torch.nn.Module):
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.kernel.apply_weights(self, y, bias)
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class TestBlockFP8Layer:
"""
Test helper for blockwise FP8 linear operations. Creates random weights
and scales for W8A8BlockFp8LinearOp.
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
abstraction (ScaledMMLinearKernel) for blockwise quantization.
Args:
weight_shape: Shape of the weight tensor (out_features, in_features).
group_shape: Blockwise quantization group shape.
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
use_aiter_and_is_supported: Whether to use aiter quantization ops.
transpose_weights: Whether to transpose weights after creation.
"""
def __init__(
self,
weight_shape: tuple[int, int],
group_shape: GroupShape,
cutlass_block_fp8_supported: bool = False,
use_aiter_and_is_supported: bool = False,
transpose_weights: bool = False,
):
weight_scale_shape = weight_shape[0] // group_shape[1]
self.weight_scale = torch.rand(
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
)
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
if transpose_weights:
self.weight = self.weight.t()
self.linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)
def __call__(
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.linear_op.apply(
input=y,
weight=self.weight,
weight_scale=self.weight_scale,
input_scale=self.input_scale,
bias=bias,
)
def is_quant_fp8_enabled(self) -> bool:
return self.linear_op.input_quant_op.enabled()

View File

@@ -1002,11 +1002,11 @@ class VllmBackend:
)
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
hash_content.append(filepath)
try:
with open(filepath) as f:
hash_content.append(f.read())

View File

@@ -19,6 +19,10 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.base import (
MMLinearKernel,
MMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision import (
MPLinearKernel,
MPLinearLayerConfig,
@@ -52,24 +56,30 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel,
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.deep_gemm import (
DeepGemmFp8BlockScaledMMKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
@@ -84,6 +94,7 @@ from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonFp8BlockScaledMMKernel,
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
@@ -128,6 +139,23 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
],
}
# in priority/performance order (when available)
_POSSIBLE_FP8_BLOCK_KERNELS: dict[
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]]
] = {
PlatformEnum.CUDA: [
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
DeepGemmFp8BlockScaledMMKernel,
CutlassFp8BlockScaledMMKernel,
TritonFp8BlockScaledMMKernel,
],
PlatformEnum.ROCM: [
AiterFp8BlockScaledMMKernel,
TritonFp8BlockScaledMMKernel,
],
}
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
@@ -152,8 +180,10 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
# TODO make all kernels inherit from MMLinearKernel
# then bound _KernelT only to MMLinearKernel
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=MMLinearLayerConfig)
def is_supported_and_can_implement_kernel(
@@ -243,32 +273,61 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel(
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
weight_shape: tuple[int, int],
input_dtype: torch.dtype,
out_dtype: torch.dtype,
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
force_kernel: type[_KernelT] | None = None,
module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype,
out_dtype=out_dtype,
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
)
if activation_quant_key.scale.group_shape.is_per_group():
kernel_type = choose_scaled_mm_linear_kernel(
config=scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_FP8_BLOCK_KERNELS, # type: ignore[misc]
force_kernel=force_kernel,
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
return kernel_type(
scaled_mm_linear_kernel_config,
)
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
)
else:
kernel_type = choose_scaled_mm_linear_kernel(
config=scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc]
force_kernel=force_kernel,
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=[
"weight",
"weight_scale",
"input_scale",
"input_scale_ub",
],
)
def init_int8_linear_kernel(
@@ -433,4 +492,7 @@ __all__ = [
"MarlinLinearKernel",
"XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel",
"_KernelT",
"DeepGemmFp8BlockScaledMMKernel",
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel",
]

View File

@@ -0,0 +1,324 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Generic, TypeVar
import torch
from typing_extensions import Self
@dataclass
class MMLinearLayerConfig:
    """Marker base class for linear-kernel layer configurations.

    Concrete kernel configs (e.g. FP8/Int8 scaled-MM configs) subclass this so
    that generic kernel code can be bounded on a single config type.
    """
@dataclass
class Params:
"""Base class for quantized layer parameters.
This class provides a typed interface for accessing quantized weights and scales
from layer modules. It serves as a parameter container that can be extracted from
layers and passed to kernel implementations.
Attributes:
weight: The quantized weight tensor
weight_scale: weight scaling factors
input_scale: Optional input scaling factors
Class Variables:
WEIGHT: Attribute name for weight tensor on the layer module
WEIGHT_SCALE: Attribute name for weight scale tensor on the layer module
INPUT_SCALE: Attribute name for input scale tensor on the layer module
Important:
The string values of WEIGHT, WEIGHT_SCALE, and INPUT_SCALE class variables
MUST match the attribute names used in the corresponding quantization method's
create_weights() implementation.
For example, if FP8LinearMethod.create_weights()
sets layer.weight and layer.weight_scale,
then WEIGHT="weight" and
WEIGHT_SCALE="weight_scale" must be used here.
Usage:
```python
# Extract parameters from a quantized layer
params = Params.from_layer(layer)
# Access typed parameters
output = func(input, params.weight, params.weight_scale)
```
"""
weight: torch.Tensor
weight_scale: torch.Tensor
input_scale: torch.Tensor | None
# Attribute names on the layer
WEIGHT: ClassVar[str] = "weight"
WEIGHT_SCALE: ClassVar[str] = "weight_scale"
INPUT_SCALE: ClassVar[str] = "input_scale"
@classmethod
def from_layer(cls, layer: torch.nn.Module) -> Self:
return cls(
weight=getattr(layer, cls.WEIGHT),
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
input_scale=getattr(layer, cls.INPUT_SCALE, None),
)
@dataclass
class FP8Params(Params):
    """FP8 layer parameters: adds the optional input-scale upper bound."""

    input_scale_ub: torch.Tensor | None

    # Attribute name of the input-scale upper bound on the layer module.
    INPUT_SCALE_UB: ClassVar[str] = "input_scale_ub"

    @classmethod
    def from_layer(cls, layer: torch.nn.Module) -> "FP8Params":
        """Extract all FP8 parameter tensors from *layer*."""
        weight = getattr(layer, cls.WEIGHT)
        weight_scale = getattr(layer, cls.WEIGHT_SCALE)
        # Optional tensors default to None when absent on the layer.
        input_scale = getattr(layer, cls.INPUT_SCALE, None)
        input_scale_ub = getattr(layer, cls.INPUT_SCALE_UB, None)
        return cls(weight, weight_scale, input_scale, input_scale_ub)
@dataclass
class Int8Params(Params):
    """Int8 layer parameters: ``Params`` plus an optional input zero point and
    an optional azp adjustment tensor."""

    input_zero_point: torch.Tensor | None
    azp_adj: torch.Tensor | None

    INPUT_ZERO_POINT: ClassVar[str] = "input_zero_point"
    AZP_ADJ: ClassVar[str] = "azp_adj"

    @classmethod
    def from_layer(cls, layer: torch.nn.Module) -> "Int8Params":
        """Extract Int8 parameters from *layer*; optional attributes that are
        absent default to None."""
        optional = {
            "input_scale": cls.INPUT_SCALE,
            "input_zero_point": cls.INPUT_ZERO_POINT,
            "azp_adj": cls.AZP_ADJ,
        }
        values = {name: getattr(layer, attr, None) for name, attr in optional.items()}
        return cls(
            weight=getattr(layer, cls.WEIGHT),
            weight_scale=getattr(layer, cls.WEIGHT_SCALE),
            **values,
        )
# Type variables that bind a kernel implementation to its parameter container
# (a Params subclass) and its configuration (an MMLinearLayerConfig subclass).
_ParamsT = TypeVar("_ParamsT", bound=Params)
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
class MMLinearKernel(ABC, Generic[_ConfigT, _ParamsT]):
    """Abstract base class for quantized matrix multiplication kernels.

    This class provides the interface for implementing custom quantized linear
    layer kernels in vLLM. Subclasses should implement specific quantization
    strategies (e.g., FP8, INT8) and their corresponding compute kernels.

    Generic Type Parameters:
        _ConfigT: Configuration type for the kernel (subclass of
            MMLinearLayerConfig). Contains kernel-specific settings like
            quantization keys, dtypes, etc.
        _ParamsT: Parameter type for the kernel (subclass of Params).
            Defines the quantized weights and scales needed by the kernel.

    Typical Usage:
        1. Define a config dataclass inheriting from MMLinearLayerConfig
        2. Define a params dataclass inheriting from Params (or FP8Params/Int8Params)
        3. Subclass MMLinearKernel with your config and params types
        4. Implement all abstract methods
        5. Register the kernel with the quantization method

    Example:
        ```python
        @dataclass
        class MyKernelConfig(MMLinearLayerConfig):
            static: bool
            output_dtype: torch.dtype

        @dataclass
        class MyKernelParams(FP8Params):
            custom_scale: torch.Tensor
            CUSTOM_SCALE: ClassVar[str] = "custom_scale"

        class MyKernel(MMLinearKernel[MyKernelConfig, MyKernelParams]):
            @classmethod
            def is_supported(cls, compute_capability=None):
                if compute_capability and compute_capability < 90:
                    return False, "Requires compute capability >= 9.0"
                return True, None

            @classmethod
            def can_implement(cls, config):
                if not config.static:
                    return False, "Only static quantization supported"
                return True, None

            def process_weights_after_loading(self, layer):
                # Preprocess weights for the kernel
                params = self._get_layer_params(layer)
                processed = preprocess_weights(params.weight)
                replace_parameter(layer, params.WEIGHT, processed)

            def _get_layer_params(self, layer, **kwargs):
                return MyKernelParams.from_layer(layer)

            def apply_weights(self, layer, x, bias=None, **kwargs):
                params = self._get_layer_params(layer)
                # Call your custom kernel
                output = my_custom_kernel(x, params.weight, params.weight_scale)
                if bias is not None:
                    output += bias
                return output
        ```

    Lifecycle:
        1. Kernel selection: is_supported() and can_implement() check compatibility
        2. Initialization: __init__() creates kernel instance with config
        3. Weight loading: process_weights_after_loading() preprocesses weights
        4. Inference: apply_weights() executes the quantized matmul
    """

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check if this kernel is supported on the current hardware.

        This method checks hardware-level compatibility (e.g., GPU
        architecture, compute capability, available instructions). It's called
        during kernel selection to filter out kernels that cannot run on the
        current device.

        Args:
            compute_capability: GPU compute capability (e.g., 80 for A100,
                90 for H100). If None, should check the current device.

        Returns:
            A tuple of (is_supported, reason):
            - is_supported: True if the kernel can run on this hardware
            - reason: If not supported, a string explaining why; otherwise None
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, config: _ConfigT) -> tuple[bool, str | None]:
        """Check if this kernel can implement the given configuration.

        This method checks configuration-level compatibility (e.g.,
        quantization scheme, group sizes, static vs dynamic quantization).
        It's called after is_supported() to determine if this kernel can
        handle the specific quantization configuration.

        Args:
            config: The kernel configuration to check

        Returns:
            A tuple of (can_implement, reason):
            - can_implement: True if this kernel supports the config
            - reason: If not supported, a string explaining why; otherwise None
        """
        raise NotImplementedError

    def __init__(self, config: _ConfigT) -> None:
        """Initialize the kernel with the given configuration.

        Args:
            config: Kernel-specific configuration containing settings like
                quantization keys, output dtypes, etc.
        """
        self.config = config

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process and transform weights after loading from checkpoint.

        This method is called once after weights are loaded but before
        inference. Use it to preprocess weights into the format required by
        your kernel (e.g., reordering, padding, format conversion).

        Modifications should be done in-place using replace_parameter() to
        ensure the layer's parameters are properly updated.

        Args:
            layer: The layer module containing the weights to process

        Example:
            ```python
            def process_weights_after_loading(self, layer):
                params = self._get_layer_params(layer)
                # Reorder weights for better memory access
                weight_reordered = reorder_weights(params.weight)
                replace_parameter(layer, params.WEIGHT, weight_reordered)
            ```
        """
        raise NotImplementedError

    # return a covariant type in the subclass
    @abstractmethod
    def _get_layer_params(self, layer: torch.nn.Module, **kwargs: Any) -> _ParamsT:
        """Extract typed parameters from the layer module.

        This internal method retrieves the quantized weights and scales from
        the layer as a typed parameter object. Subclasses should typically
        delegate to ParamsClass.from_layer().

        Args:
            layer: The layer module containing the parameters
            **kwargs: Additional arguments

        Returns:
            A typed parameter object containing weights, scales, and other
            quantization parameters

        Example:
            ```python
            def _get_layer_params(self, layer, **kwargs):
                return MyKernelParams.from_layer(layer)
            ```
        """
        raise NotImplementedError

    def get_output_padding(self) -> int | None:
        """Get the number of output tokens to pad for this kernel.

        Some kernels require input padding for optimal performance.
        Override this method to specify padding requirements.

        Returns:
            Number of tokens to pad, or None for no padding (default)
        """
        return None

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply the quantized weights to the input tensor.

        This is the main inference method that performs the quantized matrix
        multiplication. It should handle input quantization (if needed), call
        the underlying kernel, and apply bias.

        Args:
            layer: The layer module containing the quantized weights
            x: Input tensor of shape [..., in_features]
            bias: Optional bias tensor of shape [out_features]
            **kwargs: Additional kernel-specific arguments

        Returns:
            Output tensor of shape [..., out_features]
        """
        raise NotImplementedError

View File

@@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar
import torch
from typing_extensions import Self
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_block_strategy,
)
from vllm.model_executor.utils import replace_parameter
from ..base import (
FP8Params,
MMLinearKernel,
)
from .ScaledMMLinearKernel import FP8ScaledMMLinearLayerConfig
@dataclass
class FP8BlockParams(FP8Params):
    """Block-quantized FP8 layer parameters.

    Some quantization methods register the block weight scale under
    ``weight_scale_inv`` instead of ``weight_scale``, so both fields are
    optional here; whichever attribute is absent on the layer is None.
    """

    weight_scale_inv: torch.Tensor | None
    weight_scale: torch.Tensor | None

    WEIGHT_SCALE_INV: ClassVar[str] = "weight_scale_inv"

    @classmethod
    def from_layer(cls, layer: torch.nn.Module) -> "FP8BlockParams":
        """Extract block-FP8 parameters from *layer*; only ``weight`` is
        required."""
        optional = {
            "weight_scale_inv": cls.WEIGHT_SCALE_INV,
            "weight_scale": cls.WEIGHT_SCALE,
            "input_scale": cls.INPUT_SCALE,
            "input_scale_ub": cls.INPUT_SCALE_UB,
        }
        values = {name: getattr(layer, attr, None) for name, attr in optional.items()}
        return cls(weight=getattr(layer, cls.WEIGHT), **values)
class Fp8BlockScaledMMLinearKernel(
    MMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8BlockParams], ABC
):
    """Base class for FP8 block-scaled matmul kernels.

    Handles the shared apply_weights flow (input reshaping, optional
    activation quantization, bias addition) and delegates the actual GEMM to
    the subclass-provided apply_block_scaled_mm().
    """

    # Set to False in subclasses that accept BF16 input directly (e.g. FlashInfer)
    # and therefore do not need the input quantization step in apply_weights.
    apply_input_quant: ClassVar[bool] = True

    def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
        super().__init__(config)
        act_scale_descriptor = config.activation_quant_key.scale
        # Block shape of the weight scales, e.g. (128, 128).
        self.weight_group_shape = config.weight_quant_key.scale.group_shape
        # Activation quantizer configured from the activation quant key.
        self.quant_fp8 = QuantFP8(
            static=act_scale_descriptor.static,
            group_shape=act_scale_descriptor.group_shape,
            num_token_padding=self.get_output_padding(),
            use_ue8m0=False,
        )
        # Subclasses may flip this to route quant_fp8 through its Triton path.
        self.use_triton = False

    @classmethod
    def can_implement(
        cls, config: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Reject static activation quantization; only dynamic is handled."""
        act_quant_key = config.activation_quant_key
        if act_quant_key.scale.static:
            return (
                False,
                "Only dynamic per token group activation quantization is supported.",
            )
        return True, None

    def _get_layer_params(self, layer: torch.nn.Module, **kwargs) -> FP8BlockParams:
        """Read the block-FP8 weight/scale tensors off the layer."""
        return FP8BlockParams.from_layer(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Apply the block-strategy weight transform in place on the layer.

        Writes the processed weight back under the same attribute name the
        quantization method originally registered (weight_scale or
        weight_scale_inv).
        """
        params = self._get_layer_params(layer)
        # Fp8LinearMethod registered weight scale
        # buffer as weight_scale_inv unlike compressed tensors.
        weight_scale = (
            params.weight_scale
            if params.weight_scale_inv is None
            else params.weight_scale_inv
        )
        scale_attr_name = (
            params.WEIGHT_SCALE
            if params.weight_scale_inv is None
            else params.WEIGHT_SCALE_INV
        )
        new_weight, new_weight_scale = process_fp8_weight_block_strategy(
            params.weight,
            weight_scale,
        )
        replace_parameter(layer, params.WEIGHT, new_weight.data)
        replace_parameter(layer, scale_attr_name, new_weight_scale.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Quantize the activation (unless apply_input_quant is False), run the
        block-scaled GEMM, add bias, and restore the original leading shape."""
        out_dtype = self.config.out_dtype
        params = self._get_layer_params(layer)
        weight = params.weight
        weight_scale = (
            params.weight_scale
            if params.weight_scale_inv is None
            else params.weight_scale_inv
        )
        input_scale = params.input_scale
        scale_up = params.input_scale_ub
        # View input as 2D matrix for fp8 methods
        input_2d = x.view(-1, x.shape[-1])
        output_shape = [*x.shape[:-1], weight.shape[0]]
        if self.apply_input_quant:
            q_input, input_scale = self.quant_fp8(
                input_2d, input_scale, scale_up, use_triton=self.use_triton
            )
        else:
            q_input = input_2d
            # Provide a concrete placeholder so apply_block_scaled_mm args are
            # always Tensors. Subclasses with apply_input_quant=False must not
            # use As in apply_block_scaled_mm.
            input_scale = (
                input_scale if input_scale is not None else input_2d.new_ones(1)
            )
        output = self.apply_block_scaled_mm(
            A=q_input,
            B=weight,
            As=input_scale,
            Bs=weight_scale,
        )
        if bias is not None:
            output = output + bias
        return output.to(dtype=out_dtype).view(*output_shape)

    @abstractmethod
    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the block-scaled GEMM A @ B^T with scales As (activation)
        and Bs (weight). Implemented by concrete backends."""
        raise NotImplementedError
class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC):
    """Runtime-dispatching FP8 block-scaled kernel.

    Inherits apply_weights from Fp8BlockScaledMMLinearKernel; concrete
    subclasses override apply_block_scaled_mm to dispatch between two
    sub-kernels (e.g. via torch.cond).

    Subclasses must define:
        base_type: The primary kernel class.
        fallback_type: The fallback kernel class.
    """

    base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
    fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]

    def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None:
        super().__init__(config)
        # Instantiate both candidates up front; dispatch happens per call.
        self.base = self.base_type(config)
        self.fallback = self.fallback_type(config)

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Supported only when BOTH sub-kernels are supported."""
        base_ok, reason_1 = cls.base_type.is_supported(compute_capability)
        fallback_ok, reason_2 = cls.fallback_type.is_supported(
            compute_capability
        )
        failures: list[str] = []
        if not base_ok:
            failures.append(f"base is not supported due to {reason_1}")
        if not fallback_ok:
            failures.append(f"fallback is not supported due to {reason_2}")
        if failures:
            return False, "; ".join(failures)
        return True, None

    @classmethod
    def can_implement(
        cls, config: "FP8ScaledMMLinearLayerConfig"
    ) -> tuple[bool, str | None]:
        """Implementable only when BOTH sub-kernels accept the config."""
        base_ok, reason_1 = cls.base_type.can_implement(config)
        fallback_ok, reason_2 = cls.fallback_type.can_implement(config)
        failures: list[str] = []
        if not base_ok:
            failures.append(f"base cannot implement due to {reason_1}")
        if not fallback_ok:
            failures.append(f"fallback cannot implement due to {reason_2}")
        if failures:
            return False, "; ".join(failures)
        return True, None

View File

@@ -14,14 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
@dataclass
class ScaledMMLinearLayerConfig:
pass
from ..base import MMLinearLayerConfig
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
class Int8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme: bool
is_channelwise: bool
@@ -29,10 +26,12 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
class FP8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
weight_quant_key: QuantKey
activation_quant_key: QuantKey
out_dtype: torch.dtype | None
weight_shape: tuple[int, int]
input_dtype: torch.dtype
out_dtype: torch.dtype
_FP8ParamsT = tuple[
@@ -50,7 +49,7 @@ _Int8ParamsT = tuple[
]
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):

View File

@@ -4,6 +4,9 @@
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
@@ -31,7 +34,6 @@ from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
@@ -55,4 +57,5 @@ __all__ = [
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
"Fp8BlockScaledMMLinearKernel",
]

View File

@@ -6,8 +6,15 @@ import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
@@ -107,3 +114,54 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
# b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
    """FP8 block-scaled GEMM kernel backed by ROCm aiter.

    Dispatches to aiter's Triton block-scale GEMM when the platform does not
    use fnuz FP8 and aiter reports a tuned configuration for this (N, K);
    otherwise uses aiter's default block-scale GEMM.
    """

    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
        super().__init__(config)
        n, k = config.weight_shape
        # Triton path only for non-fnuz fp8 platforms with a tuned (N, K) shape.
        self.use_triton = (
            not current_platform.is_fp8_fnuz()
            and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
        )

    @classmethod
    def is_supported(cls, compute_capability=None):
        """Supported only when aiter's linear path is enabled.

        Fixed to follow the MMLinearKernel contract: the reason string is
        returned only on failure (previously it was returned even when
        supported), and the backslash line continuation that embedded a run of
        spaces inside the message is replaced by implicit concatenation.
        """
        if rocm_aiter_ops.is_linear_enabled():
            return True, None
        return (
            False,
            "Only supported on ROCm platform with aiter package installed.",
        )

    @classmethod
    def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
        """Require dynamic per-token-group activation quant with (1, 128) groups."""
        can_implement_base, reason = super().can_implement(config)
        if not can_implement_base:
            return can_implement_base, reason
        act_quant_desc = config.activation_quant_key.scale
        if act_quant_desc.group_shape != GroupShape(1, 128):
            return (
                False,
                "Supports only dynamic per token group activation "
                "quantization with group_shape=(1,128).",
            )
        return True, None

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Run the aiter block-scale GEMM path selected in __init__."""
        out_dtype = self.config.out_dtype
        if self.use_triton:
            gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
        else:
            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
        return gemm_a8w8_blockscale_op(
            A, B, As, Bs, list(self.weight_group_shape), output_dtype=out_dtype
        )

View File

@@ -5,12 +5,19 @@
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
convert_to_channelwise,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
@@ -171,3 +178,143 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
return output.view(*output_shape)
class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
    """FP8 block-scaled GEMM kernel backed by CUTLASS.

    On Hopper (device capability 90) the GEMM goes through the
    ``padded_cutlass`` custom op, which pads the token dimension to a multiple
    of 4; on other devices it calls ops.cutlass_scaled_mm directly.
    """

    def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
        super().__init__(config)
        act_scale_descriptor = config.activation_quant_key.scale
        self.weight_group_shape = config.weight_quant_key.scale.group_shape
        # Re-create the activation quantizer with column-major scales, which
        # the CUTLASS block-scale path consumes.
        self.quant_fp8 = QuantFP8(
            static=act_scale_descriptor.static,
            group_shape=act_scale_descriptor.group_shape,
            num_token_padding=self.get_output_padding(),
            use_ue8m0=False,
            column_major_scales=True,
        )
        self.is_hopper = current_platform.is_device_capability(90)

    @classmethod
    def is_supported(cls, compute_capability=None):
        """Supported only where CUTLASS block-FP8 kernels are available.

        Fixed: the failure message previously rendered as
        "...capability of90 is not supported." (missing space before the
        interpolated capability).
        """
        if not CUTLASS_BLOCK_FP8_SUPPORTED:
            return (
                False,
                "The device compute capability of "
                f"{compute_capability} is not supported.",
            )
        return True, None

    @classmethod
    def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
        """Require dynamic per-token-group activation quant with (1, 128) groups."""
        can_implement_base, reason = super().can_implement(config)
        if not can_implement_base:
            return can_implement_base, reason
        act_quant_desc = config.activation_quant_key.scale
        if act_quant_desc.group_shape != GroupShape(1, 128):
            return (
                False,
                "Supports only dynamic per token group activation "
                "quantization with group_shape=(1,128).",
            )
        return True, None

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Dispatch to the padded custom op on Hopper, else plain CUTLASS."""
        out_dtype = self.config.out_dtype
        if self.is_hopper:
            return torch.ops.vllm.padded_cutlass(
                A,
                B,
                As,
                Bs,
                list(self.weight_group_shape),
                out_dtype,
            )
        else:
            return ops.cutlass_scaled_mm(
                A,
                B.T,
                out_dtype=out_dtype,
                scale_a=As,
                scale_b=Bs.T,
            )
def cutlass_scaled_mm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Thin wrapper around ops.cutlass_scaled_mm with B/Bs transposed.

    `block_size` is accepted to keep a uniform call signature but is not
    consumed here.
    """
    weight_t = B.T
    weight_scale_t = Bs.T
    return ops.cutlass_scaled_mm(
        A,
        weight_t,
        out_dtype=output_dtype,
        scale_a=As,
        scale_b=weight_scale_t,
    )
def _padded_cutlass(
    qx: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Run cutlass_scaled_mm with the row (token) dimension of `qx` padded up
    to a multiple of 4, then slice the result back to the original row count.

    Pad rows of the activation are zeros and their scales are ones, so the
    extra output rows are simply discarded.
    """
    pad_multiple = 4
    dim = qx.shape[0]
    # Round the row count up to the next multiple of pad_multiple.
    padded = (
        dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
    )
    has_pad = padded > dim
    if has_pad:
        padded_shape = [padded, *qx.shape[1:]]
        padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
        padded_qx[0 : qx.shape[0], ...].copy_(qx)
        # Allocate the scale buffer with swapped dims and permute back so the
        # padded scales keep a column-major memory layout.
        padded_x_scale_shape = [*x_scale.shape[1:], padded]
        padded_x_scale = torch.ones(
            padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
        ).permute(-1, -2)
        padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
        output = cutlass_scaled_mm(
            padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
        )
        return output[0 : qx.shape[0], ...]
    else:
        return cutlass_scaled_mm(
            qx, weight, x_scale, weight_scale, block_size, output_dtype
        )
def _padded_cutlass_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
# Expose the padded CUTLASS GEMM as torch.ops.vllm.padded_cutlass so
# torch.compile can trace it; the fake impl supplies the meta shape function.
direct_register_custom_op(
    "padded_cutlass",
    _padded_cutlass,
    fake_impl=_padded_cutlass_fake,
)

View File

@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
deepgemm_post_process_fp8_weight_block,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_auto_disable_deep_gemm,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
    """FP8 block-scaled GEMM kernel backed by DeepGEMM (CUDA only)."""

    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
        super().__init__(config)
        # Whether DeepGEMM's UE8M0 scale format is in use on this platform.
        self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
        act_scale_descriptor = config.activation_quant_key.scale
        self.is_deep_gemm_supported = is_deep_gemm_supported()
        # Activation quantizer configured for DeepGEMM: dynamic, column-major
        # scales, optional TMA alignment and UE8M0 scale encoding.
        self.quant_fp8 = QuantFP8(
            static=False,
            group_shape=act_scale_descriptor.group_shape,
            use_ue8m0=self.use_deep_gemm_e8m0,
            tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
            column_major_scales=True,
        )

    @classmethod
    def is_supported(cls, compute_capability=None):
        """DeepGEMM requires a CUDA platform and a supported GPU generation."""
        if not current_platform.is_cuda():
            return False, "DeepGEMM is only supported on cuda platform"
        if not is_deep_gemm_supported():
            return False, "Currently, only Hopper and Blackwell GPUs are supported."
        return True, None

    @classmethod
    def can_implement(cls, config):
        """Check DeepGEMM-specific config constraints on top of the base ones:
        bfloat16 output, (1, 128) activation groups, a present model config,
        and the DeepGEMM heuristics for this model/shape."""
        can_implement_base, reason = super().can_implement(config)
        if not can_implement_base:
            return can_implement_base, reason
        if config.out_dtype != torch.bfloat16:
            return (False, "Supports only output dtype of bfloat16")
        act_quant_desc = config.activation_quant_key.scale
        if act_quant_desc.group_shape != GroupShape(1, 128):
            return (
                False,
                "Supports only dynamic per token group activation "
                "quantization with group_shape=(1,128).",
            )
        model_config = get_current_vllm_config().model_config
        if model_config is None:
            return False, "Model configuration is required."
        model_type = getattr(model_config.hf_text_config, "model_type", None)
        if should_auto_disable_deep_gemm(model_type):
            return False, f"Should not use deepgemm for model {model_type}"
        if not should_use_deepgemm_for_fp8_linear(
            config.out_dtype, config.weight_shape
        ):
            return False, "The provided metadata is not supported."
        return True, None

    def process_weights_after_loading(self, layer):
        """Apply the base block-strategy transform, then DeepGEMM's own
        weight/scale post-processing when DeepGEMM is available."""
        super().process_weights_after_loading(layer)
        params = self._get_layer_params(layer)
        assert layer.weight_block_size is not None
        if self.is_deep_gemm_supported:
            weight_scale_invs = params.weight_scale_inv
            # Write back under whichever attribute name actually holds the scale.
            scale_attr = (
                params.WEIGHT_SCALE_INV
                if weight_scale_invs is not None
                else params.WEIGHT_SCALE
            )
            dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
                wq=params.weight,
                ws=weight_scale_invs
                if weight_scale_invs is not None
                else params.weight_scale,
                quant_block_shape=tuple(layer.weight_block_size),
                use_e8m0=self.use_deep_gemm_e8m0,
            )
            replace_parameter(layer, params.WEIGHT, dg_weight)
            replace_parameter(layer, scale_attr, dg_weight_scale)

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Run DeepGEMM's NT GEMM through the fp8_gemm_nt_op custom op, which
        writes into a preallocated output buffer."""
        out_dtype = self.config.out_dtype
        output = torch.empty(
            (A.shape[0], B.shape[0]),
            dtype=out_dtype,
            device=A.device,
        )
        torch.ops.vllm.fp8_gemm_nt_op(A, As, B, Bs, output, self.use_deep_gemm_e8m0)
        return output
def _fp8_gemm_nt_op(
    q_input: torch.Tensor,
    input_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    output: torch.Tensor,
    use_deep_gemm_e8m0: bool,
) -> None:
    """Eager wrapper around DeepGEMM's NT GEMM; writes the result into
    `output` in place."""
    lhs = (q_input, input_scale)
    rhs = (weight, weight_scale)
    fp8_gemm_nt(lhs, rhs, output, is_deep_gemm_e8m0_used=use_deep_gemm_e8m0)
def _fp8_gemm_nt_op_fake(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
return None
# Register the DeepGEMM NT GEMM as torch.ops.vllm.fp8_gemm_nt_op; the op
# mutates `output` in place, which must be declared in the custom-op schema.
direct_register_custom_op(
    "fp8_gemm_nt_op",
    _fp8_gemm_nt_op,
    mutates_args=["output"],
    fake_impl=_fp8_gemm_nt_op_fake,
)

View File

@@ -2,11 +2,32 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch
import vllm.envs as envs
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.flashinfer import (
flashinfer_fp8_blockscale_gemm,
flashinfer_scaled_fp8_mm,
has_flashinfer,
is_flashinfer_fp8_blockscale_gemm_supported,
should_use_flashinfer_for_blockscale_fp8_gemm,
)
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledDynamicMMLinearKernel,
Fp8BlockScaledMMLinearKernel,
)
from .deep_gemm import DeepGemmFp8BlockScaledMMKernel, fp8_gemm_nt
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
@@ -55,3 +76,256 @@ class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
return flashinfer_scaled_fp8_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
class FlashInferFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
    """FP8 block-scaled GEMM kernel backed by FlashInfer.

    FlashInfer consumes the BF16 activation directly and performs the FP8
    conversion internally, so the base class's input-quantization step is
    skipped via apply_input_quant=False. The redundant ``__init__`` that only
    forwarded to ``super().__init__`` has been removed.
    """

    # FlashInfer accepts BF16 input and handles FP8 conversion internally.
    apply_input_quant: ClassVar[bool] = False

    @classmethod
    def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
        """Require (1, 128) activation groups and FlashInfer's own heuristics
        to approve the dtype/shape combination."""
        can_implement_base, reason = super().can_implement(config)
        if not can_implement_base:
            return can_implement_base, reason
        act_quant_desc = config.activation_quant_key.scale
        if act_quant_desc.group_shape != GroupShape(1, 128):
            return (
                False,
                "Supports only dynamic per token group activation "
                "quantization with group_shape=(1,128).",
            )
        if not should_use_flashinfer_for_blockscale_fp8_gemm(
            is_flashinfer_fp8_blockscale_gemm_supported(),
            config.out_dtype,
            config.input_dtype,
            config.weight_quant_key.dtype,
            config.weight_shape,
        ):
            return (
                False,
                "The provided metadata is not supported.",
            )
        return True, None

    @classmethod
    def is_supported(cls, compute_capability=None):
        """Requires a CUDA platform with FlashInfer's block-scale FP8 GEMM."""
        if not current_platform.is_cuda():
            return False, "only cuda devices are supported."
        if not is_flashinfer_fp8_blockscale_gemm_supported():
            return False, "FlashInfer block-scale FP8 GEMM is not available."
        return True, None

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        # A is BF16 — FlashInfer handles FP8 conversion internally.
        # As is a placeholder (apply_input_quant=False) and is not used here.
        return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
            A,  # BF16 input
            B,  # FP8 weight
            Bs,  # Weight scales
        )
class FlashInferFp8DeepGEMMDynamicBlockScaledKernel(
    Fp8BlockScaledDynamicMMLinearKernel
):
    """
    Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.

    Dispatches between two kernels based on input batch size:
    - Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
    - Large batches (M >= 32): DeepGEMM for peak throughput.

    apply_input_quant is False because FlashInfer accepts BF16 input and
    handles FP8 conversion internally. The DeepGEMM branch therefore
    quantises BF16→FP8 inside apply_mm via a closure before dispatching to
    the DeepGEMM kernel — keeping both branches compatible with the single
    BF16 tensor operand list passed by torch.cond.
    """

    base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
        FlashInferFp8BlockScaledMMKernel
    )
    fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
        DeepGemmFp8BlockScaledMMKernel
    )
    apply_input_quant: ClassVar[bool] = False

    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
        super().__init__(config)
        # Annotation-only statements: narrow the sub-kernel types assigned by
        # the parent constructor for type checkers.
        self.base: FlashInferFp8BlockScaledMMKernel
        self.fallback: DeepGemmFp8BlockScaledMMKernel

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # DeepGEMM need post-processing; both kernels share the same
        # parameter tensor layout so processing once is sufficient.
        self.fallback.process_weights_after_loading(layer)

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the batch-size-conditional custom op; `As` is a
        placeholder (apply_input_quant=False) and is not forwarded."""
        group_size = self.weight_group_shape.col
        use_deep_gemm_e8m0 = self.fallback.use_deep_gemm_e8m0
        return torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm(
            A, B, Bs, group_size, use_deep_gemm_e8m0
        )
def _flashinfer_fp8_blockscale_gemm_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Eager implementation: delegate to FlashInfer's block-scale FP8 GEMM,
    always producing a bfloat16 output."""
    return flashinfer_fp8_blockscale_gemm(
        input=input, weight=weight, weight_scale=weight_scale, out_dtype=torch.bfloat16
    )
def _flashinfer_fp8_blockscale_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
"""
Required fake/meta implementation for torch.compile graph tracing.
"""
return torch.empty(
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
)
# Expose the FlashInfer block-scale GEMM as
# torch.ops.vllm.flashinfer_fp8_blockscale_gemm for torch.compile tracing.
direct_register_custom_op(
    "flashinfer_fp8_blockscale_gemm",
    _flashinfer_fp8_blockscale_gemm_impl,
    fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
)
def _dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
    """
    Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.

    This function switches between two optimized kernels based on the input batch size:
    - For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
    - For larger batches (M >= 32): Uses the official DeepGEMM kernel.

    The conditional logic must use torch.cond() instead of a simple if-else statement
    to maintain compatibility with torch.compile graph compilation.

    This batch-size-dependent selection is essential for maintaining model accuracy.
    Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
    when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
    drop.

    Args:
        input: High-precision (bfloat16) input of shape (batch_size, input_dim);
            the DeepGEMM branch quantizes it to FP8 internally, while the
            FlashInfer branch consumes it directly
        weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
        weight_scale: Scale factors for weight quantization (per-group)
        group_size: Quantization group size for the weight tensor
        use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization

    Returns:
        Output tensor of shape (batch_size, output_dim) in bfloat16 format
    """

    def run_flashinfer_deepgemm_swapAB(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        # FlashInfer path: takes the BF16 input as-is.
        return flashinfer_fp8_blockscale_gemm(
            input=input,
            weight=weight,
            weight_scale=weight_scale,
            out_dtype=torch.bfloat16,
        )

    def run_deepgemm(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        # DeepGEMM path: quantize the activation to FP8 first, then run the
        # in-place NT GEMM into a preallocated bfloat16 buffer.
        q_input, input_scale = per_token_group_quant_fp8(
            input,
            group_size=group_size,
            column_major_scales=True,
            use_ue8m0=use_deep_gemm_e8m0,
        )
        output = torch.empty(
            (q_input.shape[0], weight.shape[0]),
            dtype=torch.bfloat16,
            device=q_input.device,
        )
        fp8_gemm_nt(
            (q_input, input_scale),
            (weight, weight_scale),
            output,
            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
        )
        return output

    if envs.VLLM_BATCH_INVARIANT:
        # Batch-invariant mode: always take the DeepGEMM path.
        return run_deepgemm(input, weight, weight_scale)
    condition = input.shape[0] < 32
    # PyTorch's torch.compile cannot handle input-dependent control flow in standard
    # Python conditionals. torch.cond() explicitly registers both code paths in the
    # computation graph, allowing torch.compile to capture both branches.
    # without torch.cond, the M < 32 condition won't be able to be captured by torch
    # compile
    return torch.cond(
        condition,
        run_flashinfer_deepgemm_swapAB,
        run_deepgemm,
        (input, weight, weight_scale),
    )
def _dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
    """Fake/meta implementation required for torch.compile graph tracing.

    Only reproduces the real op's output contract: an uninitialized
    (M, N) bfloat16 tensor on the input's device (M = input rows,
    N = weight rows). `group_size` and `use_deep_gemm_e8m0` do not affect
    the output shape.
    """
    m, n = input.shape[0], weight.shape[0]
    return input.new_empty((m, n), dtype=torch.bfloat16)
# Register the batch-size-dependent GEMM as the custom op
# `torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm`; the fake impl
# lets torch.compile trace the op without running either real kernel.
direct_register_custom_op(
    "dynamic_flashinfer_deepgemm_blockscale_gemm",
    _dynamic_flashinfer_deepgemm_blockscale_gemm_impl,
    fake_impl=_dynamic_flashinfer_deepgemm_blockscale_gemm_fake,
)

View File

@@ -13,7 +13,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
)
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig,
@@ -150,3 +154,67 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
out -= (x_s * w_s_row * azp_adj).to(x.dtype)
return out
class TritonFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
@classmethod
def is_supported(cls, compute_capability=None):
if not current_platform.is_cuda_alike():
return False, "only cuda like devices are supported."
return True, None
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
A,
B,
As,
Bs,
list(self.weight_group_shape),
self.config.out_dtype,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm,
)
return w8a8_triton_block_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
)

View File

@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
@@ -16,18 +17,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
process_fp8_weight_channel_strategy,
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
@@ -67,6 +66,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
@@ -75,21 +75,18 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(*self.weight_block_size)
)
self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=self.act_q_group_shape
)
else:
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
weight_quant_key = weight_quant_key_mapping[self.strategy]
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
self.activation_quant_key = activation_quant_key_mapping[
is_static_input_scheme
]
self.weight_quant_key = weight_quant_key_mapping[self.strategy]
@classmethod
def get_min_capability(cls) -> int:
@@ -146,6 +143,15 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
layer.register_parameter("input_scale", input_scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
@@ -163,10 +169,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False
weight, weight_scale = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale
)
input_scale = None
self.fp8_linear.process_weights_after_loading(layer)
layer.input_scale = None
# fp8_linear.process_weights_after_loading applies the post process
# and reassigns the weight and weight_scale buffers to layer attributes.
return
else:
raise ValueError(
@@ -185,8 +193,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else:
layer.input_scale = None
if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer)
if hasattr(self, "fp8_linear"):
self.fp8_linear.process_weights_after_loading(layer)
@@ -197,13 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@@ -7,6 +7,7 @@ import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
@@ -93,12 +94,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype()
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights(
self,
@@ -149,6 +145,15 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
)
layer.input_scale_ub = input_scale_ub
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: Module) -> None:
# required by torch.compile
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

View File

@@ -10,7 +10,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
@@ -45,13 +45,10 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
maybe_post_process_fp8_weight_block,
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
process_fp8_weight_tensor_strategy_moe,
validate_fp8_block_shape,
@@ -61,6 +58,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
is_layer_skipped,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
@@ -273,12 +271,13 @@ class Fp8LinearMethod(LinearMethodBase):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = False
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
if self.quant_config.use_deep_gemm is not None:
self.use_deep_gemm = self.quant_config.use_deep_gemm
else:
@@ -288,37 +287,26 @@ class Fp8LinearMethod(LinearMethodBase):
self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
if self.block_quant:
weight_quant_key = kFp8Static128BlockSym
else:
weight_quant_key = kFp8StaticTensorSym
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
if self.block_quant and not self.use_marlin:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
use_deep_gemm=self.use_deep_gemm,
self.activation_quant_key = create_fp8_quant_key(
static=self.act_q_static,
group_shape=GroupShape(1, self.weight_block_size[0]),
)
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(*self.weight_block_size)
)
else:
self.weight_quant_key = kFp8StaticTensorSym
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
self.activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
self.activation_quant_key = kFp8DynamicTokenSym
else:
self.activation_quant_key = kFp8DynamicTensorSym
def create_weights(
self,
@@ -384,6 +372,17 @@ class Fp8LinearMethod(LinearMethodBase):
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
def process_weights_after_loading(self, layer: Module) -> None:
if self.use_marlin:
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
@@ -398,13 +397,7 @@ class Fp8LinearMethod(LinearMethodBase):
if self.block_quant:
assert not self.act_q_static
weight, weight_scale_inv = process_fp8_weight_block_strategy(
layer.weight, layer.weight_scale_inv
)
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
self.fp8_linear.process_weights_after_loading(layer)
# If checkpoint not serialized fp8, quantize the weights.
else:
@@ -435,9 +428,6 @@ class Fp8LinearMethod(LinearMethodBase):
else:
layer.input_scale = None
if self.block_quant and self.use_deep_gemm:
maybe_post_process_fp8_weight_block(layer)
def apply(
self,
layer: torch.nn.Module,
@@ -449,12 +439,10 @@ class Fp8LinearMethod(LinearMethodBase):
if envs.VLLM_BATCH_INVARIANT:
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
return self.fp8_linear.apply_weights(
layer,
x,
bias,
)
else:
# per-tensor/channel: dequant to BF16 and run GEMM
@@ -483,17 +471,6 @@ class Fp8LinearMethod(LinearMethodBase):
if self.use_marlin:
return self.fp8_linear.apply_weights(layer, x, bias)
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)
@@ -538,6 +515,24 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
initialize_online_processing(layer)
# TODO: remove this check once the following RFC is resolved.
# https://github.com/vllm-project/vllm/issues/33314
# This check is required because Mxfp8OnlineLinearMethod inherits from
# Fp8OnlineLinearMethod but only calls super().create_weights(), so we must
# skip the fp8_linear kernel creation.
if hasattr(self, "mxfp8_linear"):
return
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return

View File

@@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.attention import Attention, MLAAttention
@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
)
@@ -78,6 +78,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
@@ -86,7 +87,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Static,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import (
@@ -450,12 +450,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8StaticTensorSym,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights(
self,
@@ -505,6 +501,15 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8StaticTensorSym,
weight_quant_key=kFp8StaticTensorSym,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
max_w_scale = layer.weight_scale.max()
@@ -536,12 +541,8 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights(
self,
@@ -587,6 +588,15 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
@@ -616,12 +626,16 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
self.quant_config = quant_config
block_n, block_k = self._WEIGHT_BLOCK_SIZE
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(block_n, block_k),
act_quant_group_shape=GroupShape(1, block_k),
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=GroupShape(1, block_k)
)
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(block_n, block_k)
)
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights(
self,
@@ -688,8 +702,17 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
self.w8a8_block_fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
# Keep weight in [out, in] layout for Fp8BlockScaledMMLinearKernel.
layer.weight = Parameter(layer.weight.data, requires_grad=False)
scale = layer.weight_scale
@@ -713,13 +736,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
bias=bias,
)
return self.w8a8_block_fp8_linear.apply_weights(layer, x, bias)
class ModelOptFp8MoEMethod(FusedMoEMethodBase):

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
@@ -28,13 +28,9 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
from vllm.model_executor.layers.linear import (
LinearMethodBase,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
@@ -42,7 +38,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
cutlass_fp8_supported,
)
from vllm.model_executor.model_loader.reload.layerwise import (
@@ -51,7 +46,7 @@ from vllm.model_executor.model_loader.reload.layerwise import (
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
from vllm.utils.deep_gemm import per_block_cast_to_fp8
# ---------------------------------------------------------------------------
# Online FP8 Linear Methods
@@ -64,6 +59,10 @@ class _Fp8OnlineLinearBase(LinearMethodBase):
uses_meta_device: bool = True
def __init__(self):
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights(
self,
layer: torch.nn.Module,
@@ -103,18 +102,41 @@ class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
Loads fp16/bf16 weights and quantizes them per-tensor during loading."""
def __init__(self):
self.out_dtype = torch.get_default_dtype()
super().__init__()
self.weight_quant_key = kFp8StaticTensorSym
# Use per-token quantization for better perf if dynamic and cutlass
if cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
self.activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
self.activation_quant_key = kFp8DynamicTensorSym
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
super().create_weights(
layer,
input_size_per_partition,
output_partition_sizes,
input_size,
output_size,
params_dtype,
**extra_weight_attrs,
)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
@@ -166,19 +188,14 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
Loads fp16/bf16 weights and quantizes them per-block during loading."""
def __init__(self):
self.out_dtype = torch.get_default_dtype()
super().__init__()
self.weight_block_size = [128, 128]
self.use_deep_gemm = is_deep_gemm_supported()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
use_deep_gemm=self.use_deep_gemm,
self.activation_quant_key = create_fp8_quant_key(
static=False,
group_shape=GroupShape(1, self.weight_block_size[0]),
)
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(*self.weight_block_size)
)
def create_weights(
@@ -202,6 +219,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
)
layer.weight_block_size = self.weight_block_size
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
@@ -213,14 +239,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
layer.weight, block_size=block_size, use_ue8m0=False
)
qweight, weight_scale_inv = process_fp8_weight_block_strategy(
qweight, weight_scale_inv
)
replace_parameter(layer, "weight", qweight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
maybe_post_process_fp8_weight_block(layer)
self.fp8_linear.process_weights_after_loading(layer)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
@@ -234,12 +256,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
assert self.weight_block_size is not None
# Note: batch invariance already handled in the function below
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
return self.fp8_linear.apply_weights(
layer,
x,
bias,
)

View File

@@ -7,6 +7,7 @@ from typing import Any, cast
import torch
from torch.nn import Parameter
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
@@ -57,6 +58,7 @@ class QuarkW8A8Fp8(QuarkScheme):
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
)
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
@classmethod
def get_min_capability(cls) -> int:
@@ -175,7 +177,9 @@ class QuarkW8A8Fp8(QuarkScheme):
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(),
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)

View File

@@ -12,15 +12,11 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
all_close_1d,
per_tensor_dequantize,
)
@@ -29,22 +25,14 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_tma_aligned_size,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
transform_sf_into_required_layout,
)
from vllm.utils.flashinfer import (
flashinfer_fp8_blockscale_gemm,
is_flashinfer_fp8_blockscale_gemm_supported,
should_use_flashinfer_for_blockscale_fp8_gemm,
)
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -56,153 +44,6 @@ def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
scale_b=Bs.T,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return w8a8_triton_block_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
)
def _padded_cutlass(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
pad_multiple = 4
dim = qx.shape[0]
padded = (
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
)
has_pad = padded > dim
if has_pad:
padded_shape = [padded, *qx.shape[1:]]
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0 : qx.shape[0], ...].copy_(qx)
padded_x_scale_shape = [*x_scale.shape[1:], padded]
padded_x_scale = torch.ones(
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
).permute(-1, -2)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
output = cutlass_scaled_mm(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
)
return output[0 : qx.shape[0], ...]
else:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _padded_cutlass_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
# Register the padding-aware cutlass GEMM as a vLLM custom op; the fake impl
# reports the unpadded output metadata for torch.compile tracing.
direct_register_custom_op(
    "padded_cutlass",
    _padded_cutlass,
    fake_impl=_padded_cutlass_fake,
)
def _fp8_gemm_nt_op(
    q_input: torch.Tensor,
    input_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    output: torch.Tensor,
    use_deep_gemm_e8m0: bool,
) -> None:
    """Custom-op wrapper for DeepGEMM's FP8 NT GEMM.

    The result is written into `output` in place (hence this op is
    registered with mutates_args=["output"]).
    """
    lhs = (q_input, input_scale)
    rhs = (weight, weight_scale)
    fp8_gemm_nt(lhs, rhs, output, is_deep_gemm_e8m0_used=use_deep_gemm_e8m0)
def _fp8_gemm_nt_op_fake(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
return None
# fp8_gemm_nt writes its result in place, so the output buffer must be
# declared via mutates_args for correct aliasing info under torch.compile.
direct_register_custom_op(
    "fp8_gemm_nt_op",
    _fp8_gemm_nt_op,
    mutates_args=["output"],
    fake_impl=_fp8_gemm_nt_op_fake,
)
def _triton_per_token_group_quant_fp8_impl(
x: torch.Tensor,
group_size: int,
@@ -236,362 +77,6 @@ direct_register_custom_op(
)
def _flashinfer_fp8_blockscale_gemm_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
    """FP8 blockscale GEMM with batch-size-dependent kernel selection.

    Small batches (M < 32) are routed to FlashInfer's DeepGEMM swapAB
    kernel, larger batches to the official DeepGEMM kernel. The selection
    is accuracy-critical: GSM8K benchmarks on DeepSeek-V3.1 show a large
    accuracy gap (88% vs 95%) when FlashInfer's DeepGEMM is used for
    M >= 32, and routing M < 32 to swapAB fixes the drop.

    The branch is expressed with torch.cond() instead of a plain Python
    if/else so torch.compile can register both paths in the traced graph.

    Args:
        input: (batch_size, input_dim) activation tensor.
        weight: (output_dim, input_dim) weight tensor in FP8 format.
        weight_scale: per-group scale factors for the weight.
        group_size: quantization group size for the weight tensor.
        use_deep_gemm_e8m0: whether DeepGEMM quantization uses E8M0 scales.

    Returns:
        (batch_size, output_dim) output tensor in bfloat16.
    """

    def small_batch_swap_ab(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        # FlashInfer's DeepGEMM swapAB kernel handles activation
        # quantization internally.
        return flashinfer_fp8_blockscale_gemm(
            input=input,
            weight=weight,
            weight_scale=weight_scale,
            out_dtype=torch.bfloat16,
        )

    def large_batch_deepgemm(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        # Quantize the activation per token group, then run the official
        # DeepGEMM kernel into a preallocated bf16 buffer.
        q_input, input_scale = per_token_group_quant_fp8(
            input,
            group_size=group_size,
            column_major_scales=True,
            use_ue8m0=use_deep_gemm_e8m0,
        )
        result = torch.empty(
            (q_input.shape[0], weight.shape[0]),
            dtype=torch.bfloat16,
            device=q_input.device,
        )
        fp8_gemm_nt(
            (q_input, input_scale),
            (weight, weight_scale),
            result,
            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
        )
        return result

    # Batch-invariant mode pins a single kernel regardless of M.
    if envs.VLLM_BATCH_INVARIANT:
        return large_batch_deepgemm(input, weight, weight_scale)

    # torch.compile cannot capture input-dependent Python control flow;
    # torch.cond() records both branches so the M < 32 selection survives
    # graph compilation.
    return torch.cond(
        input.shape[0] < 32,
        small_batch_swap_ab,
        large_batch_deepgemm,
        (input, weight, weight_scale),
    )
def _flashinfer_fp8_blockscale_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
"""
Required fake/meta implementation for torch.compile graph tracing.
"""
return torch.empty(
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
)
# Register the conditional FlashInfer/DeepGEMM blockscale GEMM as a vLLM
# custom op; the fake impl is required for torch.compile graph tracing.
direct_register_custom_op(
    "flashinfer_fp8_blockscale_gemm",
    _flashinfer_fp8_blockscale_gemm_impl,
    fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
    """
    Executes a block-quantized W8A8 FP8 linear layer.

    At init time one "blockscale" GEMM backend (cutlass, AITER, or Triton)
    plus a matching activation quantizer are selected; at apply() time the
    call may still be rerouted to FlashInfer or DeepGEMM depending on the
    output dtype and the weight shape.
    """
    def __init__(
        self,
        weight_group_shape: GroupShape,
        act_quant_group_shape: GroupShape,
        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
        use_aiter_and_is_supported: bool = False,
        use_deep_gemm: bool | None = None,
    ):
        # Quantization group shapes for the weight (block) and the
        # activation (per-token, per-group) scales.
        self.weight_group_shape = weight_group_shape
        self.act_quant_group_shape = act_quant_group_shape
        # use_deep_gemm=None means "auto-detect"; an explicit bool overrides.
        if use_deep_gemm is not None:
            self.is_deep_gemm_supported = use_deep_gemm
        else:
            self.is_deep_gemm_supported = is_deep_gemm_supported()
        # SM90 (Hopper) needs the padded cutlass path; see _run_cutlass.
        self.is_hopper = current_platform.is_device_capability(90)
        self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
        self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
        # Get the correct blockscale mul and input quant operations.
        # We can't use _dispatch_w8a8_blockscale_op to figure out if we want
        # to use deepgemm because we don't know the shape of weights (and
        # whether deepgemm supports it) at the init time.
        self.w8a8_blockscale_op, self.input_quant_op = (
            self._dispatch_w8a8_blockscale_op(
                cutlass_block_fp8_supported, use_aiter_and_is_supported
            )
        )
        # Separate quantizer for the DeepGEMM path: it needs column-major
        # (and optionally TMA-aligned / UE8M0) scales.
        self.deepgemm_input_quant_op = (
            QuantFP8(
                False,
                self.act_quant_group_shape,
                column_major_scales=True,
                tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
                use_ue8m0=self.use_deep_gemm_e8m0,
            )
            if self.is_deep_gemm_supported
            else None
        )
    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute input @ weight.T (+ bias) with dynamic activation quant.

        Backend precedence: FlashInfer (when both the FlashInfer and
        DeepGEMM eligibility checks pass), then DeepGEMM, then the
        blockscale op chosen at init time. `input_scale` must be None:
        activations are quantized dynamically inside each backend.
        """
        assert input_scale is None
        # View input as 2D matrix for fp8 methods
        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[0]]
        output_dtype = input.dtype
        if should_use_flashinfer_for_blockscale_fp8_gemm(
            self.is_flashinfer_supported, output_dtype, input_2d, weight
        ) and should_use_deepgemm_for_fp8_linear(
            output_dtype, weight, self.is_deep_gemm_supported
        ):
            output = self._run_flashinfer(input_2d, weight, weight_scale)
        elif should_use_deepgemm_for_fp8_linear(
            output_dtype, weight, self.is_deep_gemm_supported
        ):
            output = self._run_deepgemm(input_2d, weight, weight_scale)
        else:
            output = self.w8a8_blockscale_op(
                input_2d, weight, weight_scale, input_scale
            )
        # Bias is applied after the GEMM, in the output dtype.
        if bias is not None:
            output = output + bias
        return output.to(dtype=input.dtype).view(*output_shape)
    def _run_deepgemm(
        self,
        input_2d: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Quantize the activation and run DeepGEMM into a bf16 buffer."""
        assert self.deepgemm_input_quant_op is not None
        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
        # DeepGEMM always produces bfloat16; apply() casts back afterwards.
        output = torch.empty(
            (q_input.shape[0], weight.shape[0]),
            dtype=torch.bfloat16,
            device=q_input.device,
        )
        torch.ops.vllm.fp8_gemm_nt_op(
            q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
        )
        return output
    def _run_cutlass(
        self,
        input_2d: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize the activation and run the cutlass block-scaled GEMM.

        On Hopper the padded variant is used, which rounds M up to a
        multiple of 4 before the kernel call.
        """
        assert input_scale is None
        assert self.input_quant_op is not None
        q_input, input_scale = self.input_quant_op(input_2d)
        if self.is_hopper:
            return torch.ops.vllm.padded_cutlass(
                q_input,
                weight,
                input_scale,
                weight_scale,
                list(self.weight_group_shape),
                input_2d.dtype,
            )
        else:
            return cutlass_scaled_mm(
                q_input,
                weight,
                input_scale,
                weight_scale,
                list(self.weight_group_shape),
                input_2d.dtype,
            )
    def _run_aiter(
        self,
        input_2d: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the ROCm AITER block-scaled GEMM (Triton or asm kernel).

        Only GroupShape(1, 128) activation quantization is supported. If a
        non-None input_scale is passed, input_2d is assumed to already be
        quantized (apply() currently always passes None here).
        """
        assert self.act_quant_group_shape == GroupShape(1, 128)
        n, k = weight.shape
        # Prefer the tuned Triton kernel when available for this (n, k);
        # fp8_fnuz platforms fall back to the asm kernel.
        use_triton = (
            not current_platform.is_fp8_fnuz()
            and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
        )
        if use_triton:
            gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
        else:
            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
        if input_scale is not None:
            q_input = input_2d
        else:
            q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
        return gemm_a8w8_blockscale_op(
            q_input,
            weight,
            input_scale,
            weight_scale,
            list(self.weight_group_shape),
            output_dtype=input_2d.dtype,
        )
    def _run_triton(
        self,
        input_2d: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize the activation and run the Triton block-scaled GEMM."""
        assert input_scale is None
        assert self.input_quant_op is not None
        q_input, input_scale = self.input_quant_op(input_2d)
        return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
            q_input,
            weight,
            input_scale,
            weight_scale,
            list(self.weight_group_shape),
            input_2d.dtype,
        )
    def _run_flashinfer(
        self,
        input_2d: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run FlashInfer FP8 block-scale GEMM.
        This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
        and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
        """
        # Now call FlashInfer with BF16 input + FP8 weight, input will be
        # quantized with FlashInfer kernel (W8A8)
        output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
            input=input_2d,  # BF16 input
            weight=weight,  # FP8 weight
            weight_scale=weight_scale,  # Weight scales
            group_size=self.act_quant_group_shape.col,
            use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
        )
        return output
    def _dispatch_w8a8_blockscale_op(
        self,
        use_cutlass: bool,
        use_aiter_and_is_supported: bool,
    ) -> tuple[
        Callable[
            [
                torch.Tensor,
                torch.Tensor,
                torch.Tensor,
                torch.Tensor | None,
            ],
            torch.Tensor,
        ],
        QuantFP8,
    ]:
        """Select the (GEMM op, activation quantizer) pair at init time.

        Precedence: cutlass, then AITER, then the Triton fallback. Only
        the cutlass path needs column-major activation scales.
        """
        if use_cutlass:
            return self._run_cutlass, (
                QuantFP8(
                    False,
                    self.act_quant_group_shape,
                    column_major_scales=True,
                    use_ue8m0=False,
                )
            )
        if use_aiter_and_is_supported:
            return self._run_aiter, QuantFP8(
                False,
                self.act_quant_group_shape,
                column_major_scales=False,
                use_ue8m0=False,
            )
        return self._run_triton, (
            QuantFP8(
                False,
                self.act_quant_group_shape,
                column_major_scales=False,
                use_ue8m0=False,
            )
        )
def input_to_float8(
x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1612,34 +1097,6 @@ def process_fp8_weight_block_strategy(
return weight, weight_scale
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
    """Requantize a layer's block-quantized FP8 weight for DeepGEMM if needed.

    When DeepGEMM will be used for this layer's linear op, the weight and
    its block scales are post-processed in place (replaced as parameters)
    so they match DeepGEMM's expected format.
    """
    assert layer.weight_block_size is not None
    # Imported lazily; module-level import is avoided here.
    from vllm.utils.deep_gemm import (
        is_deep_gemm_e8m0_used,
        should_use_deepgemm_for_fp8_linear,
    )

    # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
    # requantize the weight and input to the specific scale
    # at the same time.
    if not should_use_deepgemm_for_fp8_linear(layer.orig_dtype, layer.weight):
        return

    # The scale parameter name differs between checkpoint formats.
    if hasattr(layer, "weight_scale_inv"):
        scale_attr = "weight_scale_inv"
    else:
        scale_attr = "weight_scale"
    new_weight, new_scale = deepgemm_post_process_fp8_weight_block(
        wq=layer.weight.data,
        ws=getattr(layer, scale_attr).data,
        quant_block_shape=tuple(layer.weight_block_size),
        use_e8m0=is_deep_gemm_e8m0_used(),
    )
    replace_parameter(layer, "weight", new_weight)
    replace_parameter(layer, scale_attr, new_scale)
def process_fp8_weight_tensor_strategy_moe(
weight: torch.Tensor,
weight_scales: torch.Tensor,

View File

@@ -171,6 +171,16 @@ kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
def create_fp8_quant_key(
    static: bool,
    group_shape: GroupShape,
    symmetric: bool = True,
    scale_dtype: torch.dtype = torch.float32,
) -> QuantKey:
    """Build a QuantKey describing an FP8 quantization scheme.

    Args:
        static: whether the scale is static (True) or dynamic (False).
        group_shape: quantization group shape for the scale.
        symmetric: whether the quantization is symmetric.
        scale_dtype: dtype the scales are stored in.
    """
    return QuantKey(
        FP8_DTYPE,
        ScaleDesc(scale_dtype, static, group_shape),
        symmetric=symmetric,
    )
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent

View File

@@ -413,7 +413,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype,
weight: torch.Tensor,
weight_shape: tuple[int, int],
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
@@ -428,8 +428,8 @@ def should_use_deepgemm_for_fp8_linear(
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
and weight.shape[0] % N_MULTIPLE == 0
and weight.shape[1] % K_MULTIPLE == 0
and weight_shape[0] % N_MULTIPLE == 0
and weight_shape[1] % K_MULTIPLE == 0
)

View File

@@ -748,8 +748,9 @@ def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
def should_use_flashinfer_for_blockscale_fp8_gemm(
is_flashinfer_supported: bool,
output_dtype: torch.dtype,
input: torch.Tensor,
weight: torch.Tensor,
input_dtype: torch.dtype,
weight_dtype: torch.dtype,
weight_shape: tuple[int, int],
):
if not is_flashinfer_supported:
return False
@@ -760,15 +761,12 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
N_MULTIPLE = 64
K_MULTIPLE = 128
weight_dtype = weight.dtype
input_dtype = input.dtype
should_use_flashinfer = (
output_dtype == torch.bfloat16
and input_dtype == torch.bfloat16
and weight_dtype == torch.float8_e4m3fn
and weight.shape[0] % N_MULTIPLE == 0
and weight.shape[1] % K_MULTIPLE == 0
and weight_shape[0] % N_MULTIPLE == 0
and weight_shape[1] % K_MULTIPLE == 0
)
return should_use_flashinfer