[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com> Signed-off-by: Maral <maralbahari.98@gmail.com>
This commit is contained in:
@@ -9,11 +9,12 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||
import torch
|
||||
|
||||
from vllm.benchmarks.lib.utils import default_vllm_config
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
@@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
weight_group_shape = GroupShape(block_n, block_k)
|
||||
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||
|
||||
linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=weight_group_shape,
|
||||
act_quant_group_shape=act_quant_group_shape,
|
||||
cutlass_block_fp8_supported=use_cutlass,
|
||||
use_aiter_and_is_supported=False,
|
||||
linear_op = init_fp8_linear_kernel(
|
||||
weight_quant_key=create_fp8_quant_key(
|
||||
static=True, group_shape=weight_group_shape
|
||||
),
|
||||
activation_quant_key=create_fp8_quant_key(
|
||||
static=False, group_shape=act_quant_group_shape
|
||||
),
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name="build_w8a8_block_fp8_runner",
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
def __init__(
|
||||
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
model = test_model_cls(hidden_size, token_num, dtype=dtype)
|
||||
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=self.vllm_config.model_config.dtype,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
weight_shape=(hidden_size, intermediate_size),
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
input_dtype=get_current_vllm_config().model_config.dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
|
||||
@@ -9,7 +9,7 @@ import vllm.config
|
||||
import vllm.ir.ops
|
||||
import vllm.plugins
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import (
|
||||
@@ -28,19 +28,23 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
AiterFp8BlockScaledMMKernel,
|
||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||
CutlassFp8BlockScaledMMKernel,
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
DeepGemmFp8BlockScaledMMKernel,
|
||||
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
RowWiseTorchFP8ScaledMMLinearKernel,
|
||||
TritonFp8BlockScaledMMKernel,
|
||||
_KernelT,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
create_fp8_quant_key,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# Blockwise group shapes (no kernel abstraction)
|
||||
(None, GroupShape(1, 128)),
|
||||
(None, GroupShape(1, 64)),
|
||||
# Blockwise group shapes
|
||||
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
|
||||
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
|
||||
]
|
||||
|
||||
# ROCm kernels
|
||||
@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# Blockwise group shapes (no kernel abstraction)
|
||||
(None, GroupShape(1, 128)),
|
||||
(None, GroupShape(1, 64)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
|
||||
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
|
||||
]
|
||||
|
||||
KERNEL_GROUPSHAPE_COMBINATIONS = (
|
||||
@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
||||
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
||||
# Blockwise (no kernel abstraction)
|
||||
(None, GroupShape(1, 128), True),
|
||||
# Blockwise
|
||||
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
|
||||
]
|
||||
|
||||
|
||||
@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
force_kernel: type[_KernelT] | None,
|
||||
group_shape: GroupShape,
|
||||
dtype: torch.dtype,
|
||||
use_aiter_fusion: bool = False,
|
||||
use_aiter_quant: bool = False,
|
||||
*args,
|
||||
@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
|
||||
is_blockwise = group_shape.is_per_group()
|
||||
|
||||
if is_blockwise:
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
block_size = group_shape.col
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
static=False, group_shape=group_shape
|
||||
)
|
||||
self.fp8_linear_layers = [
|
||||
TestBlockFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=use_aiter_quant,
|
||||
transpose_weights=use_aiter_fusion,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.enable_quant_fp8_custom_op = (
|
||||
False
|
||||
if use_aiter_quant
|
||||
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(block_size, block_size)
|
||||
)
|
||||
|
||||
else:
|
||||
is_static = group_shape == GroupShape.PER_TENSOR
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
|
||||
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
is_static, group_shape=group_shape
|
||||
)
|
||||
self.weight_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=group_shape
|
||||
)
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
# Enable aiter quantization if requested
|
||||
for layer in self.fp8_linear_layers:
|
||||
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
force_kernel=force_kernel,
|
||||
transpose_weights=use_aiter_fusion,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
# Enable aiter quantization if requested
|
||||
for layer in self.fp8_linear_layers:
|
||||
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
||||
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
|
||||
eps=eps,
|
||||
force_kernel=force_kernel,
|
||||
group_shape=group_shape,
|
||||
dtype=dtype,
|
||||
use_aiter_fusion=False,
|
||||
use_aiter_quant=False,
|
||||
)
|
||||
@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
eps=eps,
|
||||
force_kernel=force_kernel,
|
||||
group_shape=group_shape,
|
||||
dtype=dtype,
|
||||
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
|
||||
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
|
||||
)
|
||||
|
||||
@@ -66,6 +66,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.device = device
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_qo_heads,
|
||||
@@ -155,6 +156,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
device=self.device,
|
||||
input_dtype=self.dtype,
|
||||
)
|
||||
|
||||
w = kwargs.get("w")
|
||||
|
||||
@@ -74,6 +74,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.device = device
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
# Create kv_b_proj (ColumnParallelLinear) on device.
|
||||
# Reuse weights from prior model instance when available, because
|
||||
@@ -190,6 +191,7 @@ class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
device=self.device,
|
||||
input_dtype=self.dtype,
|
||||
)
|
||||
|
||||
w = kwargs.get("w")
|
||||
|
||||
@@ -36,9 +36,9 @@ from vllm.model_executor.kernels.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
@@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
|
||||
self,
|
||||
hidden_size: int,
|
||||
force_kernel: FP8ScaledMMLinearKernel,
|
||||
dtype: torch.dtype,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@@ -68,6 +72,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
force_kernel=force_kernel,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
@@ -137,14 +142,20 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, **kwargs):
|
||||
act_quant_key = kFp8Dynamic128Sym
|
||||
|
||||
def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=GroupShape(1, 128),
|
||||
cutlass_block_fp8_supported=False,
|
||||
use_aiter_and_is_supported=True,
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(hidden_size, hidden_size)
|
||||
)
|
||||
|
||||
self.w8a8_block_fp8_linear = TestFP8Layer(
|
||||
weight_shape=(hidden_size, hidden_size),
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
activation_quant_key=self.act_quant_key,
|
||||
input_dtype=dtype,
|
||||
)
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
@@ -157,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
|
||||
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant(
|
||||
|
||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
|
||||
model = model_class(
|
||||
hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype
|
||||
)
|
||||
|
||||
# First dimension dynamic
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
@@ -246,8 +246,9 @@ def default_vllm_config():
|
||||
"""
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
yield
|
||||
config = VllmConfig()
|
||||
with set_current_vllm_config(config):
|
||||
yield config
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -12,8 +12,8 @@ from tests.kernels.quant_utils import (
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import cutlass_scaled_mm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm,
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_triton_block_scaled_mm,
|
||||
)
|
||||
@@ -202,7 +202,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
|
||||
# only aligned sizes are supported by deepgemm
|
||||
if not should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
|
||||
output_dtype=out_dtype, weight_shape=B_fp32.shape, supports_deep_gemm=True
|
||||
):
|
||||
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig,
|
||||
@@ -29,7 +32,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsWNA16,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
@@ -473,16 +475,14 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
|
||||
assert isinstance(
|
||||
qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
|
||||
)
|
||||
assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel)
|
||||
|
||||
assert qkv_proj.weight.dtype is fp8_dtype
|
||||
assert qkv_proj.weight_scale.dtype is torch.float32
|
||||
assert len(qkv_proj.weight.shape) == 2
|
||||
assert len(qkv_proj.weight_scale.shape) == 2
|
||||
|
||||
input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
|
||||
input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8
|
||||
assert isinstance(input_quant_op, QuantFP8)
|
||||
assert input_quant_op._forward_method in (
|
||||
input_quant_op.forward_cuda,
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
@@ -406,6 +407,8 @@ def test_fp8_reloading(
|
||||
"If this is your use case, consider using a restore function like #26327"
|
||||
)
|
||||
|
||||
# Set model config as model_config.dtype is required in Fp8LinearMethod.
|
||||
default_vllm_config.model_config = ModelConfig()
|
||||
with torch.device("cuda:0"):
|
||||
config = Fp8Config(
|
||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
|
||||
@@ -12,6 +12,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.config.model import ModelConfig
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@@ -46,7 +47,7 @@ def _snapshot_download_or_skip(model_id: str) -> str:
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
def test_modelopt_fp8_checkpoint_setup(default_vllm_config, vllm_runner):
|
||||
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
|
||||
# TODO: provide a small publicly available test checkpoint
|
||||
model_path = (
|
||||
@@ -61,6 +62,8 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
"This test requires a local ModelOpt FP8 checkpoint."
|
||||
)
|
||||
|
||||
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
|
||||
default_vllm_config.model_config = ModelConfig()
|
||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
@@ -120,11 +123,13 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
|
||||
def test_modelopt_fp8_pc_pt_checkpoint_setup(default_vllm_config, vllm_runner):
|
||||
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
|
||||
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
|
||||
model_path = _snapshot_download_or_skip(model_id)
|
||||
|
||||
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
|
||||
default_vllm_config.model_config = ModelConfig()
|
||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
@@ -181,11 +186,13 @@ def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
|
||||
def test_modelopt_fp8_pb_wo_checkpoint_setup(default_vllm_config, vllm_runner):
|
||||
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
|
||||
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
|
||||
model_path = _snapshot_download_or_skip(model_id)
|
||||
|
||||
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
|
||||
default_vllm_config.model_config = ModelConfig()
|
||||
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
|
||||
113
tests/utils.py
113
tests/utils.py
@@ -43,12 +43,10 @@ from vllm.distributed import (
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
_KernelT,
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
@@ -1811,31 +1809,52 @@ class TestFP8Layer(torch.nn.Module):
|
||||
weight_shape: tuple[int, int],
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
input_dtype: torch.dtype,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
transpose_weights: bool = False,
|
||||
device: torch.device | None = None,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None = None,
|
||||
force_kernel: type[_KernelT] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
is_static_activation_scale = activation_quant_key.scale.static
|
||||
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
||||
|
||||
self.weight_scale = torch.rand(
|
||||
weight_scale_shape, dtype=torch.float32, device=device
|
||||
)
|
||||
self.input_scale = (
|
||||
torch.rand(1, dtype=torch.float32, device=device)
|
||||
if is_static_activation_scale
|
||||
else None
|
||||
)
|
||||
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
||||
self.input_scale_ub = None
|
||||
act_scale_desc = activation_quant_key.scale
|
||||
weight_scale_desc = weight_quant_key.scale
|
||||
is_block_wise = act_scale_desc.group_shape.is_per_group()
|
||||
if is_block_wise:
|
||||
block_size = weight_scale_desc.group_shape.col
|
||||
weight_scale_shape = weight_shape[0] // block_size
|
||||
self.weight_scale_inv = torch.rand(
|
||||
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
||||
)
|
||||
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
||||
self.input_scale = None
|
||||
self.weight_scale = None
|
||||
if transpose_weights:
|
||||
self.weight = self.weight.t()
|
||||
else:
|
||||
per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor()
|
||||
is_static_activation_scale = act_scale_desc.static
|
||||
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
|
||||
self.weight_scale_inv = None
|
||||
self.weight_scale = torch.rand(
|
||||
weight_scale_shape, dtype=torch.float32, device=device
|
||||
)
|
||||
self.input_scale = (
|
||||
torch.rand(1, dtype=torch.float32, device=device)
|
||||
if is_static_activation_scale
|
||||
else None
|
||||
)
|
||||
self.weight = (
|
||||
torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
|
||||
)
|
||||
self.input_scale_ub = None
|
||||
|
||||
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
|
||||
|
||||
self.kernel = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
weight_shape=weight_shape,
|
||||
input_dtype=input_dtype,
|
||||
out_dtype=out_dtype,
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
@@ -1847,61 +1866,3 @@ class TestFP8Layer(torch.nn.Module):
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(self, y, bias)
|
||||
|
||||
|
||||
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
|
||||
# after refactoring W8A8BlockFp8LinearOp.
|
||||
# https://github.com/vllm-project/vllm/issues/31818
|
||||
class TestBlockFP8Layer:
|
||||
"""
|
||||
Test helper for blockwise FP8 linear operations. Creates random weights
|
||||
and scales for W8A8BlockFp8LinearOp.
|
||||
|
||||
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
|
||||
abstraction (ScaledMMLinearKernel) for blockwise quantization.
|
||||
|
||||
Args:
|
||||
weight_shape: Shape of the weight tensor (out_features, in_features).
|
||||
group_shape: Blockwise quantization group shape.
|
||||
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
|
||||
use_aiter_and_is_supported: Whether to use aiter quantization ops.
|
||||
transpose_weights: Whether to transpose weights after creation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = False,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
transpose_weights: bool = False,
|
||||
):
|
||||
weight_scale_shape = weight_shape[0] // group_shape[1]
|
||||
self.weight_scale = torch.rand(
|
||||
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
|
||||
)
|
||||
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
|
||||
self.input_scale = None
|
||||
if transpose_weights:
|
||||
self.weight = self.weight.t()
|
||||
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.linear_op.apply(
|
||||
input=y,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
input_scale=self.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.linear_op.input_quant_op.enabled()
|
||||
|
||||
@@ -1002,11 +1002,11 @@ class VllmBackend:
|
||||
)
|
||||
hash_content = []
|
||||
for filepath in forward_code_files:
|
||||
hash_content.append(filepath)
|
||||
if filepath == "<string>":
|
||||
# This means the function was dynamically generated, with
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
hash_content.append(filepath)
|
||||
try:
|
||||
with open(filepath) as f:
|
||||
hash_content.append(f.read())
|
||||
|
||||
@@ -19,6 +19,10 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear.base import (
|
||||
MMLinearKernel,
|
||||
MMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision import (
|
||||
MPLinearKernel,
|
||||
MPLinearLayerConfig,
|
||||
@@ -52,24 +56,30 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm import (
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||
AiterFp8BlockScaledMMKernel,
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
|
||||
CutlassFp8BlockScaledMMKernel,
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
CutlassInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.deep_gemm import (
|
||||
DeepGemmFp8BlockScaledMMKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
|
||||
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
|
||||
@@ -84,6 +94,7 @@ from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||
TritonFp8BlockScaledMMKernel,
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
|
||||
@@ -128,6 +139,23 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_FP8_BLOCK_KERNELS: dict[
|
||||
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]]
|
||||
] = {
|
||||
PlatformEnum.CUDA: [
|
||||
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
|
||||
DeepGemmFp8BlockScaledMMKernel,
|
||||
CutlassFp8BlockScaledMMKernel,
|
||||
TritonFp8BlockScaledMMKernel,
|
||||
],
|
||||
PlatformEnum.ROCM: [
|
||||
AiterFp8BlockScaledMMKernel,
|
||||
TritonFp8BlockScaledMMKernel,
|
||||
],
|
||||
}
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
@@ -152,8 +180,10 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
],
|
||||
}
|
||||
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
# TODO make all kernels inherit from MMLinearKernel
|
||||
# then bound _KernelT only to MMLinearKernel
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel)
|
||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=MMLinearLayerConfig)
|
||||
|
||||
|
||||
def is_supported_and_can_implement_kernel(
|
||||
@@ -243,32 +273,61 @@ def choose_scaled_mm_linear_kernel(
|
||||
def init_fp8_linear_kernel(
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
weight_shape: tuple[int, int],
|
||||
input_dtype: torch.dtype,
|
||||
out_dtype: torch.dtype,
|
||||
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
|
||||
force_kernel: type[_KernelT] | None = None,
|
||||
module_name: str | None = None,
|
||||
) -> FP8ScaledMMLinearKernel:
|
||||
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
|
||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||
weight_quant_key=weight_quant_key,
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_shape=weight_shape,
|
||||
input_dtype=input_dtype,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
|
||||
)
|
||||
if activation_quant_key.scale.group_shape.is_per_group():
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
config=scaled_mm_linear_kernel_config,
|
||||
possible_kernels=_POSSIBLE_FP8_BLOCK_KERNELS, # type: ignore[misc]
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
if module_name:
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
if module_name:
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
return kernel_type(
|
||||
scaled_mm_linear_kernel_config,
|
||||
)
|
||||
|
||||
return kernel_type(
|
||||
scaled_mm_linear_kernel_config,
|
||||
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
|
||||
)
|
||||
else:
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
config=scaled_mm_linear_kernel_config,
|
||||
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc]
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
if module_name:
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
return kernel_type(
|
||||
scaled_mm_linear_kernel_config,
|
||||
layer_param_names=[
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_scale_ub",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def init_int8_linear_kernel(
|
||||
@@ -433,4 +492,7 @@ __all__ = [
|
||||
"MarlinLinearKernel",
|
||||
"XPUW4A8IntLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
"_KernelT",
|
||||
"DeepGemmFp8BlockScaledMMKernel",
|
||||
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel",
|
||||
]
|
||||
|
||||
324
vllm/model_executor/kernels/linear/base.py
Normal file
324
vllm/model_executor/kernels/linear/base.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMLinearLayerConfig: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params:
|
||||
"""Base class for quantized layer parameters.
|
||||
|
||||
This class provides a typed interface for accessing quantized weights and scales
|
||||
from layer modules. It serves as a parameter container that can be extracted from
|
||||
layers and passed to kernel implementations.
|
||||
|
||||
Attributes:
|
||||
weight: The quantized weight tensor
|
||||
weight_scale: weight scaling factors
|
||||
input_scale: Optional input scaling factors
|
||||
|
||||
Class Variables:
|
||||
WEIGHT: Attribute name for weight tensor on the layer module
|
||||
WEIGHT_SCALE: Attribute name for weight scale tensor on the layer module
|
||||
INPUT_SCALE: Attribute name for input scale tensor on the layer module
|
||||
|
||||
Important:
|
||||
The string values of WEIGHT, WEIGHT_SCALE, and INPUT_SCALE class variables
|
||||
MUST match the attribute names used in the corresponding quantization method's
|
||||
create_weights() implementation.
|
||||
For example, if FP8LinearMethod.create_weights()
|
||||
sets layer.weight and layer.weight_scale,
|
||||
then WEIGHT="weight" and
|
||||
WEIGHT_SCALE="weight_scale" must be used here.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
# Extract parameters from a quantized layer
|
||||
params = Params.from_layer(layer)
|
||||
|
||||
# Access typed parameters
|
||||
output = func(input, params.weight, params.weight_scale)
|
||||
```
|
||||
"""
|
||||
|
||||
weight: torch.Tensor
|
||||
weight_scale: torch.Tensor
|
||||
input_scale: torch.Tensor | None
|
||||
|
||||
# Attribute names on the layer
|
||||
WEIGHT: ClassVar[str] = "weight"
|
||||
WEIGHT_SCALE: ClassVar[str] = "weight_scale"
|
||||
INPUT_SCALE: ClassVar[str] = "input_scale"
|
||||
|
||||
@classmethod
|
||||
def from_layer(cls, layer: torch.nn.Module) -> Self:
|
||||
return cls(
|
||||
weight=getattr(layer, cls.WEIGHT),
|
||||
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
|
||||
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FP8Params(Params):
|
||||
"""FP8 layer parameters with typed fields"""
|
||||
|
||||
input_scale_ub: torch.Tensor | None
|
||||
|
||||
INPUT_SCALE_UB: ClassVar[str] = "input_scale_ub"
|
||||
|
||||
@classmethod
|
||||
def from_layer(cls, layer: torch.nn.Module) -> "FP8Params":
|
||||
"""Extract parameters from layer"""
|
||||
return cls(
|
||||
weight=getattr(layer, cls.WEIGHT),
|
||||
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
|
||||
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||
input_scale_ub=getattr(layer, cls.INPUT_SCALE_UB, None),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8Params(Params):
|
||||
"""Int8 layer parameters with typed fields"""
|
||||
|
||||
input_zero_point: torch.Tensor | None
|
||||
azp_adj: torch.Tensor | None
|
||||
|
||||
INPUT_ZERO_POINT: ClassVar[str] = "input_zero_point"
|
||||
AZP_ADJ: ClassVar[str] = "azp_adj"
|
||||
|
||||
@classmethod
|
||||
def from_layer(cls, layer: torch.nn.Module) -> "Int8Params":
|
||||
"""Extract parameters from layer"""
|
||||
return cls(
|
||||
weight=getattr(layer, cls.WEIGHT),
|
||||
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
|
||||
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||
input_zero_point=getattr(layer, cls.INPUT_ZERO_POINT, None),
|
||||
azp_adj=getattr(layer, cls.AZP_ADJ, None),
|
||||
)
|
||||
|
||||
|
||||
_ParamsT = TypeVar("_ParamsT", bound=Params)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
|
||||
|
||||
|
||||
class MMLinearKernel(ABC, Generic[_ConfigT, _ParamsT]):
|
||||
"""Abstract base class for quantized matrix multiplication kernels.
|
||||
|
||||
This class provides the interface for implementing custom quantized linear layer
|
||||
kernels in vLLM. Subclasses should implement specific quantization strategies
|
||||
(e.g., FP8, INT8) and their corresponding compute kernels.
|
||||
|
||||
Generic Type Parameters:
|
||||
_ConfigT: Configuration type for the kernel (subclass of MMLinearLayerConfig).
|
||||
Contains kernel-specific settings like quantization keys, dtypes, etc.
|
||||
_ParamsT: Parameter type for the kernel (subclass of Params).
|
||||
Defines the quantized weights and scales needed by the kernel.
|
||||
|
||||
Typical Usage:
|
||||
1. Define a config dataclass inheriting from MMLinearLayerConfig
|
||||
2. Define a params dataclass inheriting from Params (or FP8Params/Int8Params)
|
||||
3. Subclass MMLinearKernel with your config and params types
|
||||
4. Implement all abstract methods
|
||||
5. Register the kernel with the quantization method
|
||||
|
||||
Example:
|
||||
```python
|
||||
@dataclass
|
||||
class MyKernelConfig(MMLinearLayerConfig):
|
||||
static: bool
|
||||
output_dtype: torch.dtype
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyKernelParams(FP8Params):
|
||||
custom_scale: torch.Tensor
|
||||
CUSTOM_SCALE: ClassVar[str] = "custom_scale"
|
||||
|
||||
|
||||
class MyKernel(MMLinearKernel[MyKernelConfig, MyKernelParams]):
|
||||
@classmethod
|
||||
def is_supported(cls, compute_capability=None):
|
||||
if compute_capability and compute_capability < 90:
|
||||
return False, "Requires compute capability >= 9.0"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config):
|
||||
if not config.static:
|
||||
return False, "Only static quantization supported"
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
# Preprocess weights for the kernel
|
||||
params = self._get_layer_params(layer)
|
||||
processed = preprocess_weights(params.weight)
|
||||
replace_parameter(layer, params.WEIGHT, processed)
|
||||
|
||||
def _get_layer_params(self, layer, **kwargs):
|
||||
return MyKernelParams.from_layer(layer)
|
||||
|
||||
def apply_weights(self, layer, x, bias=None, **kwargs):
|
||||
params = self._get_layer_params(layer)
|
||||
# Call your custom kernel
|
||||
output = my_custom_kernel(x, params.weight, params.weight_scale)
|
||||
if bias is not None:
|
||||
output += bias
|
||||
return output
|
||||
```
|
||||
|
||||
Lifecycle:
|
||||
1. Kernel selection: is_supported() and can_implement() check compatibility
|
||||
2. Initialization: __init__() creates kernel instance with config
|
||||
3. Weight loading: process_weights_after_loading() preprocesses weights
|
||||
4. Inference: apply_weights() executes the quantized matmul
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Check if this kernel is supported on the current hardware.
|
||||
|
||||
This method checks hardware-level compatibility (e.g., GPU architecture,
|
||||
compute capability, available instructions). It's called during kernel
|
||||
selection to filter out kernels that cannot run on the current device.
|
||||
|
||||
Args:
|
||||
compute_capability: GPU compute capability (e.g., 80 for A100, 90 for H100).
|
||||
If None, should check the current device.
|
||||
|
||||
Returns:
|
||||
A tuple of (is_supported, reason):
|
||||
- is_supported: True if the kernel can run on this hardware
|
||||
- reason: If not supported, a string explaining why; otherwise None
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, config: _ConfigT) -> tuple[bool, str | None]:
|
||||
"""Check if this kernel can implement the given configuration.
|
||||
|
||||
This method checks configuration-level compatibility (e.g., quantization
|
||||
scheme, group sizes, static vs dynamic quantization). It's called after
|
||||
is_supported() to determine if this kernel can handle the specific
|
||||
quantization configuration.
|
||||
|
||||
Args:
|
||||
config: The kernel configuration to check
|
||||
|
||||
Returns:
|
||||
A tuple of (can_implement, reason):
|
||||
- can_implement: True if this kernel supports the config
|
||||
- reason: If not supported, a string explaining why; otherwise None
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, config: _ConfigT) -> None:
|
||||
"""Initialize the kernel with the given configuration.
|
||||
|
||||
Args:
|
||||
config: Kernel-specific configuration containing settings like
|
||||
quantization keys, output dtypes, etc.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Process and transform weights after loading from checkpoint.
|
||||
|
||||
This method is called once after weights are loaded but before inference.
|
||||
Use it to preprocess weights into the format required by your kernel
|
||||
(e.g., reordering, padding, format conversion).
|
||||
|
||||
Modifications should be done in-place using replace_parameter() to ensure
|
||||
the layer's parameters are properly updated.
|
||||
|
||||
Args:
|
||||
layer: The layer module containing the weights to process
|
||||
|
||||
Example:
|
||||
```python
|
||||
def process_weights_after_loading(self, layer):
|
||||
params = self._get_layer_params(layer)
|
||||
# Reorder weights for better memory access
|
||||
weight_reordered = reorder_weights(params.weight)
|
||||
replace_parameter(layer, params.WEIGHT, weight_reordered)
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# return a covariant type in the subclass
|
||||
@abstractmethod
|
||||
def _get_layer_params(self, layer: torch.nn.Module, **kwargs: Any) -> _ParamsT:
|
||||
"""Extract typed parameters from the layer module.
|
||||
|
||||
This internal method retrieves the quantized weights and scales from
|
||||
the layer as a typed parameter object. Subclasses should typically
|
||||
delegate to ParamsClass.from_layer().
|
||||
|
||||
Args:
|
||||
layer: The layer module containing the parameters
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
A typed parameter object containing weights, scales, and other
|
||||
quantization parameters
|
||||
|
||||
Example:
|
||||
```python
|
||||
def _get_layer_params(self, layer, **kwargs):
|
||||
return MyKernelParams.from_layer(layer)
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_padding(self) -> int | None:
|
||||
"""Get the number of output tokens to pad for this kernel.
|
||||
|
||||
Some kernels require input padding for optimal performance.
|
||||
Override this method to specify padding requirements.
|
||||
|
||||
Returns:
|
||||
Number of tokens to pad, or None for no padding (default)
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
"""Apply the quantized weights to the input tensor.
|
||||
|
||||
This is the main inference method that performs the quantized matrix
|
||||
multiplication. It should handle input quantization (if needed), call
|
||||
the underlying kernel, and apply bias.
|
||||
|
||||
Args:
|
||||
layer: The layer module containing the quantized weights
|
||||
x: Input tensor of shape [..., in_features]
|
||||
bias: Optional bias tensor of shape [out_features]
|
||||
**kwargs: Additional kernel-specific arguments
|
||||
|
||||
Returns:
|
||||
Output tensor of shape [..., out_features]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_weight_block_strategy,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
|
||||
from ..base import (
|
||||
FP8Params,
|
||||
MMLinearKernel,
|
||||
)
|
||||
from .ScaledMMLinearKernel import FP8ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class FP8BlockParams(FP8Params):
|
||||
weight_scale_inv: torch.Tensor | None
|
||||
weight_scale: torch.Tensor | None
|
||||
|
||||
WEIGHT_SCALE_INV: ClassVar[str] = "weight_scale_inv"
|
||||
|
||||
@classmethod
|
||||
def from_layer(cls, layer: torch.nn.Module) -> Self:
|
||||
return cls(
|
||||
weight=getattr(layer, cls.WEIGHT),
|
||||
weight_scale_inv=getattr(layer, cls.WEIGHT_SCALE_INV, None),
|
||||
weight_scale=getattr(layer, cls.WEIGHT_SCALE, None),
|
||||
input_scale=getattr(layer, cls.INPUT_SCALE, None),
|
||||
input_scale_ub=getattr(layer, cls.INPUT_SCALE_UB, None),
|
||||
)
|
||||
|
||||
|
||||
class Fp8BlockScaledMMLinearKernel(
|
||||
MMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8BlockParams], ABC
|
||||
):
|
||||
# Set to False in subclasses that accept BF16 input directly (e.g. FlashInfer)
|
||||
# and therefore do not need the input quantization step in apply_weights.
|
||||
apply_input_quant: ClassVar[bool] = True
|
||||
|
||||
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
|
||||
super().__init__(config)
|
||||
act_scale_descriptor = config.activation_quant_key.scale
|
||||
self.weight_group_shape = config.weight_quant_key.scale.group_shape
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=act_scale_descriptor.static,
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
num_token_padding=self.get_output_padding(),
|
||||
use_ue8m0=False,
|
||||
)
|
||||
self.use_triton = False
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||
act_quant_key = config.activation_quant_key
|
||||
if act_quant_key.scale.static:
|
||||
return (
|
||||
False,
|
||||
"Only dynamic per token group activation quantization is supported.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def _get_layer_params(self, layer: torch.nn.Module, **kwargs) -> FP8BlockParams:
|
||||
return FP8BlockParams.from_layer(layer)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
params = self._get_layer_params(layer)
|
||||
# Fp8LinearMethod registered weight scale
|
||||
# buffer as weight_scale_inv unlike compressed tensors.
|
||||
weight_scale = (
|
||||
params.weight_scale
|
||||
if params.weight_scale_inv is None
|
||||
else params.weight_scale_inv
|
||||
)
|
||||
scale_attr_name = (
|
||||
params.WEIGHT_SCALE
|
||||
if params.weight_scale_inv is None
|
||||
else params.WEIGHT_SCALE_INV
|
||||
)
|
||||
new_weight, new_weight_scale = process_fp8_weight_block_strategy(
|
||||
params.weight,
|
||||
weight_scale,
|
||||
)
|
||||
|
||||
replace_parameter(layer, params.WEIGHT, new_weight.data)
|
||||
replace_parameter(layer, scale_attr_name, new_weight_scale.data)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
out_dtype = self.config.out_dtype
|
||||
params = self._get_layer_params(layer)
|
||||
weight = params.weight
|
||||
weight_scale = (
|
||||
params.weight_scale
|
||||
if params.weight_scale_inv is None
|
||||
else params.weight_scale_inv
|
||||
)
|
||||
input_scale = params.input_scale
|
||||
scale_up = params.input_scale_ub
|
||||
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], weight.shape[0]]
|
||||
|
||||
if self.apply_input_quant:
|
||||
q_input, input_scale = self.quant_fp8(
|
||||
input_2d, input_scale, scale_up, use_triton=self.use_triton
|
||||
)
|
||||
else:
|
||||
q_input = input_2d
|
||||
# Provide a concrete placeholder so apply_block_scaled_mm args are
|
||||
# always Tensors. Subclasses with apply_input_quant=False must not
|
||||
# use As in apply_block_scaled_mm.
|
||||
input_scale = (
|
||||
input_scale if input_scale is not None else input_2d.new_ones(1)
|
||||
)
|
||||
|
||||
output = self.apply_block_scaled_mm(
|
||||
A=q_input,
|
||||
B=weight,
|
||||
As=input_scale,
|
||||
Bs=weight_scale,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=out_dtype).view(*output_shape)
|
||||
|
||||
@abstractmethod
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC):
|
||||
"""Dynamic FP8 block-scaled kernel that dispatches at runtime.
|
||||
|
||||
Extends Fp8BlockScaledMMLinearKernel to inherit apply_weights and overrides
|
||||
apply_block_scaled_mm to dispatch between two sub-kernels using torch.cond.
|
||||
|
||||
Subclasses must define:
|
||||
base_type: The primary kernel class.
|
||||
fallback_type: The fallback kernel class.
|
||||
"""
|
||||
|
||||
base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
|
||||
fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
|
||||
|
||||
def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None:
|
||||
super().__init__(config)
|
||||
self.base = self.base_type(config)
|
||||
self.fallback = self.fallback_type(config)
|
||||
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
is_base_supported, reason_1 = cls.base_type.is_supported(compute_capability)
|
||||
is_fallback_supported, reason_2 = cls.fallback_type.is_supported(
|
||||
compute_capability
|
||||
)
|
||||
if is_base_supported and is_fallback_supported:
|
||||
return True, None
|
||||
if not is_base_supported and not is_fallback_supported:
|
||||
return (
|
||||
False,
|
||||
f"base is not supported due to {reason_1}; "
|
||||
f"fallback is not supported due to {reason_2}",
|
||||
)
|
||||
if not is_base_supported:
|
||||
return False, f"base is not supported due to {reason_1}"
|
||||
return False, f"fallback is not supported due to {reason_2}"
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, config: "FP8ScaledMMLinearLayerConfig"
|
||||
) -> tuple[bool, str | None]:
|
||||
can_implement_base, reason_1 = cls.base_type.can_implement(config)
|
||||
can_implement_fallback, reason_2 = cls.fallback_type.can_implement(config)
|
||||
if can_implement_base and can_implement_fallback:
|
||||
return True, None
|
||||
if not can_implement_base and not can_implement_fallback:
|
||||
return (
|
||||
False,
|
||||
f"base cannot implement due to {reason_1}; "
|
||||
f"fallback cannot implement due to {reason_2}",
|
||||
)
|
||||
if not can_implement_base:
|
||||
return False, f"base cannot implement due to {reason_1}"
|
||||
return False, f"fallback cannot implement due to {reason_2}"
|
||||
@@ -14,14 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
pass
|
||||
from ..base import MMLinearLayerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
class Int8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
|
||||
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
|
||||
is_static_input_scheme: bool
|
||||
is_channelwise: bool
|
||||
@@ -29,10 +26,12 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
class FP8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
|
||||
weight_quant_key: QuantKey
|
||||
activation_quant_key: QuantKey
|
||||
out_dtype: torch.dtype | None
|
||||
weight_shape: tuple[int, int]
|
||||
input_dtype: torch.dtype
|
||||
out_dtype: torch.dtype
|
||||
|
||||
|
||||
_FP8ParamsT = tuple[
|
||||
@@ -50,7 +49,7 @@ _Int8ParamsT = tuple[
|
||||
]
|
||||
|
||||
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||
AiterInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import (
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
|
||||
CPUInt8ScaledMMLinearKernel,
|
||||
)
|
||||
@@ -31,7 +34,6 @@ from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
@@ -55,4 +57,5 @@ __all__ = [
|
||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||
"ROCmFP8ScaledMMLinearKernel",
|
||||
"TritonInt8ScaledMMLinearKernel",
|
||||
"Fp8BlockScaledMMLinearKernel",
|
||||
]
|
||||
|
||||
@@ -6,8 +6,15 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .BlockScaledMMLinearKernel import (
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||
|
||||
@@ -107,3 +114,54 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
# b to be [N, K]
|
||||
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||
|
||||
|
||||
class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
|
||||
super().__init__(config)
|
||||
n, k = config.weight_shape
|
||||
|
||||
self.use_triton = (
|
||||
not current_platform.is_fp8_fnuz()
|
||||
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, compute_capability=None):
|
||||
return (
|
||||
rocm_aiter_ops.is_linear_enabled(),
|
||||
"Only supported on ROCm platform \
|
||||
with aiter package installed.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||
can_implement_base, reason = super().can_implement(config)
|
||||
if not can_implement_base:
|
||||
return can_implement_base, reason
|
||||
|
||||
act_quant_desc = config.activation_quant_key.scale
|
||||
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||
return (
|
||||
False,
|
||||
"Supports only dynamic per token group activation "
|
||||
"quantization with group_shape=(1,128).",
|
||||
)
|
||||
return True, None
|
||||
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out_dtype = self.config.out_dtype
|
||||
if self.use_triton:
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
|
||||
else:
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
|
||||
|
||||
return gemm_a8w8_blockscale_op(
|
||||
A, B, As, Bs, list(self.weight_group_shape), output_dtype=out_dtype
|
||||
)
|
||||
|
||||
@@ -5,12 +5,19 @@
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
@@ -171,3 +178,143 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
|
||||
super().__init__(config)
|
||||
act_scale_descriptor = config.activation_quant_key.scale
|
||||
self.weight_group_shape = config.weight_quant_key.scale.group_shape
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=act_scale_descriptor.static,
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
num_token_padding=self.get_output_padding(),
|
||||
use_ue8m0=False,
|
||||
column_major_scales=True,
|
||||
)
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, compute_capability=None):
|
||||
if not CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||
return (
|
||||
False,
|
||||
"The device compute capability of"
|
||||
f"{compute_capability} is not supported.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||
can_implement_base, reason = super().can_implement(config)
|
||||
if not can_implement_base:
|
||||
return can_implement_base, reason
|
||||
|
||||
act_quant_desc = config.activation_quant_key.scale
|
||||
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||
return (
|
||||
False,
|
||||
"Supports only dynamic per token group activation "
|
||||
"quantization with group_shape=(1,128).",
|
||||
)
|
||||
return True, None
|
||||
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out_dtype = self.config.out_dtype
|
||||
if self.is_hopper:
|
||||
return torch.ops.vllm.padded_cutlass(
|
||||
A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
list(self.weight_group_shape),
|
||||
out_dtype,
|
||||
)
|
||||
else:
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs.T,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs.T,
|
||||
)
|
||||
|
||||
|
||||
def _padded_cutlass(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
pad_multiple = 4
|
||||
dim = qx.shape[0]
|
||||
padded = (
|
||||
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
|
||||
)
|
||||
|
||||
has_pad = padded > dim
|
||||
|
||||
if has_pad:
|
||||
padded_shape = [padded, *qx.shape[1:]]
|
||||
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
|
||||
padded_qx[0 : qx.shape[0], ...].copy_(qx)
|
||||
|
||||
padded_x_scale_shape = [*x_scale.shape[1:], padded]
|
||||
padded_x_scale = torch.ones(
|
||||
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
|
||||
).permute(-1, -2)
|
||||
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
||||
|
||||
output = cutlass_scaled_mm(
|
||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
return output[0 : qx.shape[0], ...]
|
||||
else:
|
||||
return cutlass_scaled_mm(
|
||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
|
||||
|
||||
def _padded_cutlass_fake(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"padded_cutlass",
|
||||
_padded_cutlass,
|
||||
fake_impl=_padded_cutlass_fake,
|
||||
)
|
||||
|
||||
156
vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py
Normal file
156
vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
deepgemm_post_process_fp8_weight_block,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_auto_disable_deep_gemm,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .BlockScaledMMLinearKernel import (
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
|
||||
super().__init__(config)
|
||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||
act_scale_descriptor = config.activation_quant_key.scale
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=False,
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
use_ue8m0=self.use_deep_gemm_e8m0,
|
||||
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
|
||||
column_major_scales=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, compute_capability=None):
|
||||
if not current_platform.is_cuda():
|
||||
return False, "DeepGEMM is only supported on cuda platform"
|
||||
if not is_deep_gemm_supported():
|
||||
return False, "Currently, only Hopper and Blackwell GPUs are supported."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config):
|
||||
can_implement_base, reason = super().can_implement(config)
|
||||
if not can_implement_base:
|
||||
return can_implement_base, reason
|
||||
if config.out_dtype != torch.bfloat16:
|
||||
return (False, "Supports only output dtype of bfloat16")
|
||||
|
||||
act_quant_desc = config.activation_quant_key.scale
|
||||
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||
return (
|
||||
False,
|
||||
"Supports only dynamic per token group activation "
|
||||
"quantization with group_shape=(1,128).",
|
||||
)
|
||||
model_config = get_current_vllm_config().model_config
|
||||
|
||||
if model_config is None:
|
||||
return False, "Model configuration is required."
|
||||
|
||||
model_type = getattr(model_config.hf_text_config, "model_type", None)
|
||||
if should_auto_disable_deep_gemm(model_type):
|
||||
return False, f"Should not use deepgemm for model {model_type}"
|
||||
|
||||
if not should_use_deepgemm_for_fp8_linear(
|
||||
config.out_dtype, config.weight_shape
|
||||
):
|
||||
return False, "The provided metadata is not supported."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super().process_weights_after_loading(layer)
|
||||
params = self._get_layer_params(layer)
|
||||
assert layer.weight_block_size is not None
|
||||
|
||||
if self.is_deep_gemm_supported:
|
||||
weight_scale_invs = params.weight_scale_inv
|
||||
scale_attr = (
|
||||
params.WEIGHT_SCALE_INV
|
||||
if weight_scale_invs is not None
|
||||
else params.WEIGHT_SCALE
|
||||
)
|
||||
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||
wq=params.weight,
|
||||
ws=weight_scale_invs
|
||||
if weight_scale_invs is not None
|
||||
else params.weight_scale,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=self.use_deep_gemm_e8m0,
|
||||
)
|
||||
replace_parameter(layer, params.WEIGHT, dg_weight)
|
||||
replace_parameter(layer, scale_attr, dg_weight_scale)
|
||||
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out_dtype = self.config.out_dtype
|
||||
output = torch.empty(
|
||||
(A.shape[0], B.shape[0]),
|
||||
dtype=out_dtype,
|
||||
device=A.device,
|
||||
)
|
||||
torch.ops.vllm.fp8_gemm_nt_op(A, As, B, Bs, output, self.use_deep_gemm_e8m0)
|
||||
return output
|
||||
|
||||
|
||||
def _fp8_gemm_nt_op(
|
||||
q_input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> None:
|
||||
fp8_gemm_nt(
|
||||
(q_input, input_scale),
|
||||
(weight, weight_scale),
|
||||
output,
|
||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||
)
|
||||
|
||||
|
||||
def _fp8_gemm_nt_op_fake(
|
||||
q_input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"fp8_gemm_nt_op",
|
||||
_fp8_gemm_nt_op,
|
||||
mutates_args=["output"],
|
||||
fake_impl=_fp8_gemm_nt_op_fake,
|
||||
)
|
||||
@@ -2,11 +2,32 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_fp8_blockscale_gemm,
|
||||
flashinfer_scaled_fp8_mm,
|
||||
has_flashinfer,
|
||||
is_flashinfer_fp8_blockscale_gemm_supported,
|
||||
should_use_flashinfer_for_blockscale_fp8_gemm,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .BlockScaledMMLinearKernel import (
|
||||
Fp8BlockScaledDynamicMMLinearKernel,
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
)
|
||||
from .deep_gemm import DeepGemmFp8BlockScaledMMKernel, fp8_gemm_nt
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
@@ -55,3 +76,256 @@ class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
return flashinfer_scaled_fp8_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
|
||||
|
||||
class FlashInferFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||
# FlashInfer accepts BF16 input and handles FP8 conversion internally.
|
||||
apply_input_quant: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
|
||||
can_implement_base, reason = super().can_implement(config)
|
||||
if not can_implement_base:
|
||||
return can_implement_base, reason
|
||||
|
||||
act_quant_desc = config.activation_quant_key.scale
|
||||
if act_quant_desc.group_shape != GroupShape(1, 128):
|
||||
return (
|
||||
False,
|
||||
"Supports only dynamic per token group activation "
|
||||
"quantization with group_shape=(1,128).",
|
||||
)
|
||||
|
||||
if not should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
is_flashinfer_fp8_blockscale_gemm_supported(),
|
||||
config.out_dtype,
|
||||
config.input_dtype,
|
||||
config.weight_quant_key.dtype,
|
||||
config.weight_shape,
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"The provided metadata is not supported.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, compute_capability=None):
|
||||
if not current_platform.is_cuda():
|
||||
return False, "only cuda devices are supported."
|
||||
|
||||
if not is_flashinfer_fp8_blockscale_gemm_supported():
|
||||
return False, "FlashInfer block-scale FP8 GEMM is not available."
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# A is BF16 — FlashInfer handles FP8 conversion internally.
|
||||
# As is a placeholder (apply_input_quant=False) and is not used here.
|
||||
return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
|
||||
A, # BF16 input
|
||||
B, # FP8 weight
|
||||
Bs, # Weight scales
|
||||
)
|
||||
|
||||
|
||||
class FlashInferFp8DeepGEMMDynamicBlockScaledKernel(
|
||||
Fp8BlockScaledDynamicMMLinearKernel
|
||||
):
|
||||
"""
|
||||
Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.
|
||||
|
||||
Dispatches between two kernels based on input batch size:
|
||||
- Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
|
||||
- Large batches (M >= 32): DeepGEMM for peak throughput.
|
||||
|
||||
apply_input_quant is False because FlashInfer accepts BF16 input and
|
||||
handles FP8 conversion internally. The DeepGEMM branch therefore
|
||||
quantises BF16→FP8 inside apply_mm via a closure before dispatching to
|
||||
the DeepGEMM kernel — keeping both branches compatible with the single
|
||||
BF16 tensor operand list passed by torch.cond.
|
||||
"""
|
||||
|
||||
base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
|
||||
FlashInferFp8BlockScaledMMKernel
|
||||
)
|
||||
fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
|
||||
DeepGemmFp8BlockScaledMMKernel
|
||||
)
|
||||
apply_input_quant: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
|
||||
super().__init__(config)
|
||||
self.base: FlashInferFp8BlockScaledMMKernel
|
||||
self.fallback: DeepGemmFp8BlockScaledMMKernel
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
# DeepGEMM need post-processing; both kernels share the same
|
||||
# parameter tensor layout so processing once is sufficient.
|
||||
self.fallback.process_weights_after_loading(layer)
|
||||
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
group_size = self.weight_group_shape.col
|
||||
use_deep_gemm_e8m0 = self.fallback.use_deep_gemm_e8m0
|
||||
|
||||
return torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm(
|
||||
A, B, Bs, group_size, use_deep_gemm_e8m0
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_fp8_blockscale_gemm_impl(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_fp8_blockscale_gemm(
|
||||
input=input,
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_fp8_blockscale_gemm_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Required fake/meta implementation for torch.compile graph tracing.
|
||||
"""
|
||||
return torch.empty(
|
||||
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"flashinfer_fp8_blockscale_gemm",
|
||||
_flashinfer_fp8_blockscale_gemm_impl,
|
||||
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
|
||||
)
|
||||
|
||||
|
||||
def _dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
|
||||
|
||||
This function switches between two optimized kernels based on the input batch size:
|
||||
- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
|
||||
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.
|
||||
|
||||
The conditional logic must use torch.cond() instead of a simple if-else statement
|
||||
to maintain compatibility with torch.compile graph compilation.
|
||||
|
||||
This batch-size-dependent selection is essential for maintaining model accuracy.
|
||||
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
|
||||
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
|
||||
drop.
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape (batch_size, input_dim) in FP8 format
|
||||
weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
|
||||
weight_scale: Scale factors for weight quantization (per-group)
|
||||
group_size: Quantization group size for the weight tensor
|
||||
use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (batch_size, output_dim) in bfloat16 format
|
||||
"""
|
||||
|
||||
def run_flashinfer_deepgemm_swapAB(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_fp8_blockscale_gemm(
|
||||
input=input,
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
def run_deepgemm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input,
|
||||
group_size=group_size,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_deep_gemm_e8m0,
|
||||
)
|
||||
output = torch.empty(
|
||||
(q_input.shape[0], weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=q_input.device,
|
||||
)
|
||||
fp8_gemm_nt(
|
||||
(q_input, input_scale),
|
||||
(weight, weight_scale),
|
||||
output,
|
||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||
)
|
||||
return output
|
||||
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
return run_deepgemm(input, weight, weight_scale)
|
||||
|
||||
condition = input.shape[0] < 32
|
||||
|
||||
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
|
||||
# Python conditionals. torch.cond() explicitly registers both code paths in the
|
||||
# computation graph, allowing torch.compile to capture both branches.
|
||||
# without torch.cond, the M < 32 condition won't be able to be captured by torch
|
||||
# compile
|
||||
return torch.cond(
|
||||
condition,
|
||||
run_flashinfer_deepgemm_swapAB,
|
||||
run_deepgemm,
|
||||
(input, weight, weight_scale),
|
||||
)
|
||||
|
||||
|
||||
def _dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Required fake/meta implementation for torch.compile graph tracing.
|
||||
"""
|
||||
return torch.empty(
|
||||
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"dynamic_flashinfer_deepgemm_blockscale_gemm",
|
||||
_dynamic_flashinfer_deepgemm_blockscale_gemm_impl,
|
||||
fake_impl=_dynamic_flashinfer_deepgemm_blockscale_gemm_fake,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .BlockScaledMMLinearKernel import (
|
||||
Fp8BlockScaledMMLinearKernel,
|
||||
)
|
||||
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
@@ -150,3 +154,67 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
out -= (x_s * w_s_row * azp_adj).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TritonFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(cls, compute_capability=None):
|
||||
if not current_platform.is_cuda_alike():
|
||||
return False, "only cuda like devices are supported."
|
||||
return True, None
|
||||
|
||||
def apply_block_scaled_mm(
|
||||
self,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
|
||||
A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
list(self.weight_group_shape),
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
|
||||
# TODO we should be able to change the type of block_size to GroupShape
|
||||
# after we resolve GroupShape compilation issue
|
||||
# https://github.com/vllm-project/vllm/issues/25270
|
||||
def _w8a8_triton_block_scaled_mm_func(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_triton_block_scaled_mm,
|
||||
)
|
||||
|
||||
return w8a8_triton_block_scaled_mm(
|
||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
|
||||
|
||||
def _w8a8_triton_block_scaled_mm_fake(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"w8a8_triton_block_scaled_mm_func",
|
||||
_w8a8_triton_block_scaled_mm_func,
|
||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
@@ -16,18 +17,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_channel_strategy,
|
||||
process_fp8_weight_tensor_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
@@ -67,6 +66,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.weight_quant = weight_quant
|
||||
self.strategy = weight_quant.strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
|
||||
@@ -75,21 +75,18 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
assert not self.is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(*self.weight_block_size)
|
||||
)
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
static=False, group_shape=self.act_q_group_shape
|
||||
)
|
||||
else:
|
||||
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
|
||||
weight_quant_key = weight_quant_key_mapping[self.strategy]
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.activation_quant_key = activation_quant_key_mapping[
|
||||
is_static_input_scheme
|
||||
]
|
||||
self.weight_quant_key = weight_quant_key_mapping[self.strategy]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -146,6 +143,15 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
@@ -163,10 +169,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
|
||||
elif self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.is_static_input_scheme is False
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale
|
||||
)
|
||||
input_scale = None
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
|
||||
layer.input_scale = None
|
||||
# fp8_linear.process_weights_after_loading applies the post process
|
||||
# and reassigns the weight and weight_scale buffers to layer attributes.
|
||||
return
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -185,8 +193,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
if hasattr(self, "fp8_linear"):
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
@@ -197,13 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.weight_block_size is not None:
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
@@ -93,12 +94,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -149,6 +145,15 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
layer.input_scale_ub = input_scale_ub
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
@@ -10,7 +10,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
@@ -45,13 +45,10 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_tensor_strategy,
|
||||
process_fp8_weight_tensor_strategy_moe,
|
||||
validate_fp8_block_shape,
|
||||
@@ -61,6 +58,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
is_layer_skipped,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
@@ -273,12 +271,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.marlin_input_dtype = None
|
||||
self.use_marlin = False
|
||||
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
if self.quant_config.use_deep_gemm is not None:
|
||||
self.use_deep_gemm = self.quant_config.use_deep_gemm
|
||||
else:
|
||||
@@ -288,37 +287,26 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if self.act_q_static:
|
||||
activation_quant_key = kFp8StaticTensorSym
|
||||
elif cutlass_fp8_supported():
|
||||
activation_quant_key = kFp8DynamicTokenSym
|
||||
else:
|
||||
activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
if self.block_quant:
|
||||
weight_quant_key = kFp8Static128BlockSym
|
||||
else:
|
||||
weight_quant_key = kFp8StaticTensorSym
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
||||
|
||||
if self.block_quant and not self.use_marlin:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
use_deep_gemm=self.use_deep_gemm,
|
||||
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
static=self.act_q_static,
|
||||
group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
)
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(*self.weight_block_size)
|
||||
)
|
||||
else:
|
||||
self.weight_quant_key = kFp8StaticTensorSym
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if self.act_q_static:
|
||||
self.activation_quant_key = kFp8StaticTensorSym
|
||||
elif cutlass_fp8_supported():
|
||||
self.activation_quant_key = kFp8DynamicTokenSym
|
||||
else:
|
||||
self.activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -384,6 +372,17 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if self.use_marlin:
|
||||
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
|
||||
@@ -398,13 +397,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
|
||||
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale_inv
|
||||
)
|
||||
|
||||
# Update layer with new values
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
else:
|
||||
@@ -435,9 +428,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
if self.block_quant and self.use_deep_gemm:
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -449,12 +439,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
return self.fp8_linear.apply_weights(
|
||||
layer,
|
||||
x,
|
||||
bias,
|
||||
)
|
||||
else:
|
||||
# per-tensor/channel: dequant to BF16 and run GEMM
|
||||
@@ -483,17 +471,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if self.use_marlin:
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
@@ -538,6 +515,24 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
|
||||
initialize_online_processing(layer)
|
||||
|
||||
# TODO: remove this check once the following RFC is resolved.
|
||||
# https://github.com/vllm-project/vllm/issues/33314
|
||||
# This check is required because Mxfp8OnlineLinearMethod inherits from
|
||||
# Fp8OnlineLinearMethod but only calls super().create_weights(), so we must
|
||||
# skip the fp8_linear kernel creation.
|
||||
if hasattr(self, "mxfp8_linear"):
|
||||
return
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
process_fp8_weight_tensor_strategy_moe,
|
||||
)
|
||||
@@ -78,6 +78,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
@@ -86,7 +87,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
@@ -450,12 +450,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8StaticTensorSym,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -505,6 +501,15 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8StaticTensorSym,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = layer.weight
|
||||
max_w_scale = layer.weight_scale.max()
|
||||
@@ -536,12 +541,8 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -587,6 +588,15 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
@@ -616,12 +626,16 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
self.quant_config = quant_config
|
||||
block_n, block_k = self._WEIGHT_BLOCK_SIZE
|
||||
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(block_n, block_k),
|
||||
act_quant_group_shape=GroupShape(1, block_k),
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
static=False, group_shape=GroupShape(1, block_k)
|
||||
)
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(block_n, block_k)
|
||||
)
|
||||
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -688,8 +702,17 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
self.w8a8_block_fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
|
||||
# Keep weight in [out, in] layout for Fp8BlockScaledMMLinearKernel.
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
scale = layer.weight_scale
|
||||
@@ -713,13 +736,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
bias=bias,
|
||||
)
|
||||
return self.w8a8_block_fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -28,13 +28,9 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
create_fp8_quant_key,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
@@ -42,7 +38,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.model_executor.model_loader.reload.layerwise import (
|
||||
@@ -51,7 +46,7 @@ from vllm.model_executor.model_loader.reload.layerwise import (
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Online FP8 Linear Methods
|
||||
@@ -64,6 +59,10 @@ class _Fp8OnlineLinearBase(LinearMethodBase):
|
||||
|
||||
uses_meta_device: bool = True
|
||||
|
||||
def __init__(self):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -103,18 +102,41 @@ class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
|
||||
Loads fp16/bf16 weights and quantizes them per-tensor during loading."""
|
||||
|
||||
def __init__(self):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
super().__init__()
|
||||
|
||||
self.weight_quant_key = kFp8StaticTensorSym
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if cutlass_fp8_supported():
|
||||
activation_quant_key = kFp8DynamicTokenSym
|
||||
self.activation_quant_key = kFp8DynamicTokenSym
|
||||
else:
|
||||
activation_quant_key = kFp8DynamicTensorSym
|
||||
self.activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
super().create_weights(
|
||||
layer,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
input_size,
|
||||
output_size,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@@ -166,19 +188,14 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
||||
Loads fp16/bf16 weights and quantizes them per-block during loading."""
|
||||
|
||||
def __init__(self):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
super().__init__()
|
||||
self.weight_block_size = [128, 128]
|
||||
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
use_deep_gemm=self.use_deep_gemm,
|
||||
self.activation_quant_key = create_fp8_quant_key(
|
||||
static=False,
|
||||
group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
)
|
||||
self.weight_quant_key = create_fp8_quant_key(
|
||||
static=True, group_shape=GroupShape(*self.weight_block_size)
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
@@ -202,6 +219,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
||||
)
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
@@ -213,14 +239,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
||||
layer.weight, block_size=block_size, use_ue8m0=False
|
||||
)
|
||||
|
||||
qweight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||
qweight, weight_scale_inv
|
||||
)
|
||||
|
||||
replace_parameter(layer, "weight", qweight.data)
|
||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
self.fp8_linear.process_weights_after_loading(layer)
|
||||
|
||||
# Prevent duplicate processing (e.g., during weight reload)
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
@@ -234,12 +256,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
|
||||
assert self.weight_block_size is not None
|
||||
|
||||
# Note: batch invariance already handled in the function below
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
return self.fp8_linear.apply_weights(
|
||||
layer,
|
||||
x,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, cast
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
@@ -57,6 +58,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.input_dtype = get_current_vllm_config().model_config.dtype
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -175,7 +177,9 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
weight_shape=layer.weight.shape,
|
||||
input_dtype=self.input_dtype,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,15 +12,11 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
get_fp8_min_max,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
all_close_1d,
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
@@ -29,22 +25,14 @@ from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
get_tma_aligned_size,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
transform_sf_into_required_layout,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_fp8_blockscale_gemm,
|
||||
is_flashinfer_fp8_blockscale_gemm_supported,
|
||||
should_use_flashinfer_for_blockscale_fp8_gemm,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -56,153 +44,6 @@ def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
|
||||
|
||||
# We need to pass in the is_hopper flag as argument because the function
|
||||
# current_platform.is_device_capability() is not supported by Torch compiler.
|
||||
def cutlass_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs.T,
|
||||
)
|
||||
|
||||
|
||||
# TODO we should be able to change the type of block_size to GroupShape
|
||||
# after we resolve GroupShape compilation issue
|
||||
# https://github.com/vllm-project/vllm/issues/25270
|
||||
def _w8a8_triton_block_scaled_mm_func(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return w8a8_triton_block_scaled_mm(
|
||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
|
||||
|
||||
def _w8a8_triton_block_scaled_mm_fake(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"w8a8_triton_block_scaled_mm_func",
|
||||
_w8a8_triton_block_scaled_mm_func,
|
||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
def _padded_cutlass(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
pad_multiple = 4
|
||||
dim = qx.shape[0]
|
||||
padded = (
|
||||
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
|
||||
)
|
||||
|
||||
has_pad = padded > dim
|
||||
|
||||
if has_pad:
|
||||
padded_shape = [padded, *qx.shape[1:]]
|
||||
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
|
||||
padded_qx[0 : qx.shape[0], ...].copy_(qx)
|
||||
|
||||
padded_x_scale_shape = [*x_scale.shape[1:], padded]
|
||||
padded_x_scale = torch.ones(
|
||||
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
|
||||
).permute(-1, -2)
|
||||
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
||||
|
||||
output = cutlass_scaled_mm(
|
||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
return output[0 : qx.shape[0], ...]
|
||||
else:
|
||||
return cutlass_scaled_mm(
|
||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
|
||||
|
||||
def _padded_cutlass_fake(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"padded_cutlass",
|
||||
_padded_cutlass,
|
||||
fake_impl=_padded_cutlass_fake,
|
||||
)
|
||||
|
||||
|
||||
def _fp8_gemm_nt_op(
|
||||
q_input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> None:
|
||||
fp8_gemm_nt(
|
||||
(q_input, input_scale),
|
||||
(weight, weight_scale),
|
||||
output,
|
||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||
)
|
||||
|
||||
|
||||
def _fp8_gemm_nt_op_fake(
|
||||
q_input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"fp8_gemm_nt_op",
|
||||
_fp8_gemm_nt_op,
|
||||
mutates_args=["output"],
|
||||
fake_impl=_fp8_gemm_nt_op_fake,
|
||||
)
|
||||
|
||||
|
||||
def _triton_per_token_group_quant_fp8_impl(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
@@ -236,362 +77,6 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_fp8_blockscale_gemm_impl(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
|
||||
|
||||
This function switches between two optimized kernels based on the input batch size:
|
||||
- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
|
||||
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.
|
||||
|
||||
The conditional logic must use torch.cond() instead of a simple if-else statement
|
||||
to maintain compatibility with torch.compile graph compilation.
|
||||
|
||||
This batch-size-dependent selection is essential for maintaining model accuracy.
|
||||
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
|
||||
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
|
||||
drop.
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape (batch_size, input_dim) in FP8 format
|
||||
weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
|
||||
weight_scale: Scale factors for weight quantization (per-group)
|
||||
group_size: Quantization group size for the weight tensor
|
||||
use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (batch_size, output_dim) in bfloat16 format
|
||||
"""
|
||||
|
||||
def run_flashinfer_deepgemm_swapAB(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_fp8_blockscale_gemm(
|
||||
input=input,
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
def run_deepgemm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input,
|
||||
group_size=group_size,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_deep_gemm_e8m0,
|
||||
)
|
||||
output = torch.empty(
|
||||
(q_input.shape[0], weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=q_input.device,
|
||||
)
|
||||
fp8_gemm_nt(
|
||||
(q_input, input_scale),
|
||||
(weight, weight_scale),
|
||||
output,
|
||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||
)
|
||||
return output
|
||||
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
return run_deepgemm(input, weight, weight_scale)
|
||||
|
||||
condition = input.shape[0] < 32
|
||||
|
||||
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
|
||||
# Python conditionals. torch.cond() explicitly registers both code paths in the
|
||||
# computation graph, allowing torch.compile to capture both branches.
|
||||
# without torch.cond, the M < 32 condition won't be able to be captured by torch
|
||||
# compile
|
||||
return torch.cond(
|
||||
condition,
|
||||
run_flashinfer_deepgemm_swapAB,
|
||||
run_deepgemm,
|
||||
(input, weight, weight_scale),
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_fp8_blockscale_gemm_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Required fake/meta implementation for torch.compile graph tracing.
|
||||
"""
|
||||
return torch.empty(
|
||||
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"flashinfer_fp8_blockscale_gemm",
|
||||
_flashinfer_fp8_blockscale_gemm_impl,
|
||||
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
|
||||
)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
class W8A8BlockFp8LinearOp:
|
||||
"""
|
||||
This class executes a Blocked FP8 linear layer using cutlass if supported
|
||||
and torch.scaled_mm otherwise.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_group_shape: GroupShape,
|
||||
act_quant_group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
use_deep_gemm: bool | None = None,
|
||||
):
|
||||
self.weight_group_shape = weight_group_shape
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
if use_deep_gemm is not None:
|
||||
self.is_deep_gemm_supported = use_deep_gemm
|
||||
else:
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
|
||||
|
||||
# Get the correct blockscale mul and input quant operations.
|
||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
||||
# to use deepgemm because we don't know the shape of weights (and
|
||||
# whether deepgemm supports it) at the init time.
|
||||
self.w8a8_blockscale_op, self.input_quant_op = (
|
||||
self._dispatch_w8a8_blockscale_op(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported
|
||||
)
|
||||
)
|
||||
self.deepgemm_input_quant_op = (
|
||||
QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
|
||||
use_ue8m0=self.use_deep_gemm_e8m0,
|
||||
)
|
||||
if self.is_deep_gemm_supported
|
||||
else None
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
|
||||
if should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
self.is_flashinfer_supported, output_dtype, input_2d, weight
|
||||
) and should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype, weight, self.is_deep_gemm_supported
|
||||
):
|
||||
output = self._run_flashinfer(input_2d, weight, weight_scale)
|
||||
|
||||
elif should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype, weight, self.is_deep_gemm_supported
|
||||
):
|
||||
output = self._run_deepgemm(input_2d, weight, weight_scale)
|
||||
else:
|
||||
output = self.w8a8_blockscale_op(
|
||||
input_2d, weight, weight_scale, input_scale
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
def _run_deepgemm(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.deepgemm_input_quant_op is not None
|
||||
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
|
||||
output = torch.empty(
|
||||
(q_input.shape[0], weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=q_input.device,
|
||||
)
|
||||
torch.ops.vllm.fp8_gemm_nt_op(
|
||||
q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
|
||||
)
|
||||
return output
|
||||
|
||||
def _run_cutlass(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
assert self.input_quant_op is not None
|
||||
q_input, input_scale = self.input_quant_op(input_2d)
|
||||
if self.is_hopper:
|
||||
return torch.ops.vllm.padded_cutlass(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype,
|
||||
)
|
||||
else:
|
||||
return cutlass_scaled_mm(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
def _run_aiter(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the ROCm AITER FP8 block-scale GEMM.

    Picks the tuned Triton kernel when available; otherwise falls back
    to the AITER ASM/HIP kernel. Quantizes the activation unless the
    caller already supplied a quantized input plus its scale.
    """
    # AITER path only supports per-token, 128-wide group activation quant.
    assert self.act_quant_group_shape == GroupShape(1, 128)

    out_features, in_features = weight.shape

    # Prefer Triton only off fnuz-FP8 platforms and when a tuned config
    # exists for this (N, K).
    prefer_triton = not current_platform.is_fp8_fnuz() and (
        rocm_aiter_ops.is_triton_gemm_w8a8_tuned(out_features, in_features)
    )

    gemm_op = (
        rocm_aiter_ops.triton_gemm_a8w8_blockscale
        if prefer_triton
        else rocm_aiter_ops.gemm_a8w8_blockscale
    )

    if input_scale is None:
        # Caller passed an unquantized activation; quantize it here.
        q_act, input_scale = self.input_quant_op(input_2d, use_triton=prefer_triton)
    else:
        # Activation is already quantized by the caller.
        q_act = input_2d

    return gemm_op(
        q_act,
        weight,
        input_scale,
        weight_scale,
        list(self.weight_group_shape),
        output_dtype=input_2d.dtype,
    )
|
||||
|
||||
def _run_triton(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Quantize the activation and run the Triton FP8 block-scale GEMM."""
    # The activation is quantized here, so no pre-computed scale is expected.
    assert input_scale is None
    assert self.input_quant_op is not None
    q_act, act_scale = self.input_quant_op(input_2d)
    return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
        q_act,
        weight,
        act_scale,
        weight_scale,
        list(self.weight_group_shape),
        input_2d.dtype,
    )
|
||||
|
||||
def _run_flashinfer(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """
    Run FlashInfer FP8 block-scale GEMM.

    This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
    and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
    """
    # FlashInfer takes the BF16 input together with the FP8 weight and
    # performs the activation quantization inside the kernel (W8A8).
    return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
        input=input_2d,  # BF16 input
        weight=weight,  # FP8 weight
        weight_scale=weight_scale,  # Weight scales
        group_size=self.act_quant_group_shape.col,
        use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
    )
|
||||
|
||||
def _dispatch_w8a8_blockscale_op(
    self,
    use_cutlass: bool,
    use_aiter_and_is_supported: bool,
) -> tuple[
    Callable[
        [
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor | None,
        ],
        torch.Tensor,
    ],
    QuantFP8,
]:
    """Select the GEMM runner and a matching activation quantizer.

    Precedence: CUTLASS, then AITER, then the Triton fallback. All
    quantizers are dynamic (non-static) over ``act_quant_group_shape``;
    only the scale layout differs per backend.
    """
    if use_cutlass:
        runner = self._run_cutlass
        # CUTLASS expects column-major activation scales.
        column_major_scales = True
    elif use_aiter_and_is_supported:
        runner = self._run_aiter
        column_major_scales = False
    else:
        runner = self._run_triton
        column_major_scales = False
    quant_op = QuantFP8(
        False,
        self.act_quant_group_shape,
        column_major_scales=column_major_scales,
        use_ue8m0=False,
    )
    return runner, quant_op
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
x: torch.Tensor, dtype: torch.dtype | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -1612,34 +1097,6 @@ def process_fp8_weight_block_strategy(
|
||||
return weight, weight_scale
|
||||
|
||||
|
||||
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
    """Requantize a block-quantized FP8 weight for DeepGemm when applicable.

    No-op unless DeepGemm would be used for this layer's FP8 linear; in
    that case the weight and its block scales are re-processed in place
    (E8M0 scales when DeepGemm E8M0 mode is active).
    """
    assert layer.weight_block_size is not None

    from vllm.utils.deep_gemm import (
        is_deep_gemm_e8m0_used,
        should_use_deepgemm_for_fp8_linear,
    )

    # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
    # requantize the weight and input to the specific scale
    # at the same time.
    if not should_use_deepgemm_for_fp8_linear(layer.orig_dtype, layer.weight):
        return

    # Checkpoints name the scale either `weight_scale_inv` or `weight_scale`.
    scale_attr = (
        "weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
    )
    new_weight, new_scale = deepgemm_post_process_fp8_weight_block(
        wq=layer.weight.data,
        ws=getattr(layer, scale_attr).data,
        quant_block_shape=tuple(layer.weight_block_size),
        use_e8m0=is_deep_gemm_e8m0_used(),
    )
    replace_parameter(layer, "weight", new_weight)
    replace_parameter(layer, scale_attr, new_scale)
|
||||
|
||||
|
||||
def process_fp8_weight_tensor_strategy_moe(
|
||||
weight: torch.Tensor,
|
||||
weight_scales: torch.Tensor,
|
||||
|
||||
@@ -171,6 +171,16 @@ kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
|
||||
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
|
||||
|
||||
|
||||
def create_fp8_quant_key(
    static: bool,
    group_shape: GroupShape,
    symmetric: bool = True,
    scale_dtype: torch.dtype = torch.float32,
) -> QuantKey:
    """Build a ``QuantKey`` describing an FP8 quantization scheme.

    ``static`` selects static vs. dynamic scales; ``group_shape`` sets the
    quantization group granularity; scales default to symmetric float32.
    """
    return QuantKey(
        FP8_DTYPE,
        ScaleDesc(scale_dtype, static, group_shape),
        symmetric=symmetric,
    )
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
|
||||
@@ -413,7 +413,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
|
||||
def should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype: torch.dtype,
|
||||
weight: torch.Tensor,
|
||||
weight_shape: tuple[int, int],
|
||||
supports_deep_gemm: bool | None = None,
|
||||
):
|
||||
if supports_deep_gemm is None:
|
||||
@@ -428,8 +428,8 @@ def should_use_deepgemm_for_fp8_linear(
|
||||
return (
|
||||
supports_deep_gemm
|
||||
and output_dtype == torch.bfloat16
|
||||
and weight.shape[0] % N_MULTIPLE == 0
|
||||
and weight.shape[1] % K_MULTIPLE == 0
|
||||
and weight_shape[0] % N_MULTIPLE == 0
|
||||
and weight_shape[1] % K_MULTIPLE == 0
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -748,8 +748,9 @@ def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
|
||||
def should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
is_flashinfer_supported: bool,
|
||||
output_dtype: torch.dtype,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_dtype: torch.dtype,
|
||||
weight_dtype: torch.dtype,
|
||||
weight_shape: tuple[int, int],
|
||||
):
|
||||
if not is_flashinfer_supported:
|
||||
return False
|
||||
@@ -760,15 +761,12 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
N_MULTIPLE = 64
|
||||
K_MULTIPLE = 128
|
||||
|
||||
weight_dtype = weight.dtype
|
||||
input_dtype = input.dtype
|
||||
|
||||
should_use_flashinfer = (
|
||||
output_dtype == torch.bfloat16
|
||||
and input_dtype == torch.bfloat16
|
||||
and weight_dtype == torch.float8_e4m3fn
|
||||
and weight.shape[0] % N_MULTIPLE == 0
|
||||
and weight.shape[1] % K_MULTIPLE == 0
|
||||
and weight_shape[0] % N_MULTIPLE == 0
|
||||
and weight_shape[1] % K_MULTIPLE == 0
|
||||
)
|
||||
|
||||
return should_use_flashinfer
|
||||
|
||||
Reference in New Issue
Block a user