[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

Robert Shaw
2026-01-21 08:22:33 -05:00
committed by GitHub
parent e14467be43
commit 42135d6898
82 changed files with 2710 additions and 1563 deletions

View File

@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
-    kNvfp4Quant,
+    kNvfp4Dynamic,
 )
 from vllm.platforms import current_platform
@@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
     torch.ops._C, "silu_and_mul_nvfp4_quant"
 )
 if silu_and_mul_nvfp4_quant_supported:
-    FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501
+    FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501


 class ActivationQuantPattern(ABC):
@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
     """

     def __init__(self) -> None:
-        super().__init__(kNvfp4Quant)
+        super().__init__(kNvfp4Dynamic)

     def get_inputs(self) -> list[torch.Tensor]:
         result = self.empty_quant(5, 32)
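
The new key name lines up with the kFp8Dynamic*/kFp8Static* keys imported alongside it, making explicit that the NVFP4 path uses dynamically computed scales. A rough sketch of the idea behind such key sentinels, with an illustrative QuantScheme dataclass standing in for vLLM's QuantKey (the field names here are assumptions, not the real definition):

from dataclasses import dataclass

@dataclass(frozen=True)
class QuantScheme:
    # Illustrative fields only; vLLM's QuantKey may differ.
    dtype: str       # e.g. "fp8", "nvfp4"
    static: bool     # True: scales precomputed; False: computed per input
    symmetric: bool

# frozen dataclasses hash by value, so these sentinels can key
# registries like FUSED_OPS / QUANT_OPS directly.
kDemoFp8StaticTensorSym = QuantScheme("fp8", static=True, symmetric=True)
kDemoNvfp4Dynamic = QuantScheme("nvfp4", static=False, symmetric=True)

registry = {kDemoNvfp4Dynamic: "silu_and_mul_nvfp4_quant"}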

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kFp8DynamicTensorSym,
     kFp8DynamicTokenSym,
     kFp8StaticTensorSym,
-    kNvfp4Quant,
+    kNvfp4Dynamic,
     kStaticTensorScale,
 )
 from vllm.platforms import current_platform
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
 }
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
+    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
 if current_platform.is_cuda():
     QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
     QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
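
Worth noting in the hunk above: several keys can resolve to the same kernel entry point (both kFp8Dynamic128Sym and kFp8Dynamic64Sym map to per_token_group_fp8_quant), and registration is gated on platform capability. A minimal self-contained sketch of that many-keys-to-one-op registry shape (demo names only, not vLLM's objects):

def demo_per_token_group_fp8_quant(x: list[float]) -> list[float]:
    # Stand-in for torch.ops._C.per_token_group_fp8_quant.default;
    # the real kernel quantizes per token group, this just passes through.
    return x

# The key encodes the quantization scheme (dtype, scale granularity,
# symmetry); the value is only the kernel entry point, so two schemes
# served by one kernel simply share a value.
QUANT_OPS = {
    "fp8-dynamic-128-sym": demo_per_token_group_fp8_quant,
    "fp8-dynamic-64-sym": demo_per_token_group_fp8_quant,
}

assert QUANT_OPS["fp8-dynamic-128-sym"] is QUANT_OPS["fp8-dynamic-64-sym"]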

View File

@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
-    kNvfp4Quant,
+    kNvfp4Dynamic,
     kStaticTensorScale,
 )
 from vllm.platforms import current_platform
@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
     """

     def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
-        super().__init__(layer, kNvfp4Quant, dtype)
+        super().__init__(layer, kNvfp4Dynamic, dtype)

     def _register(self, pm_pass: PatternMatcherPass) -> None:
         def pattern(
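
The pattern classes only ever touch the key in their constructors, which is why the rename stays one line per class. A compact sketch of that shape, with mock classes standing in for AttentionQuantPattern and its subclass:

from abc import ABC, abstractmethod

class DemoQuantPattern(ABC):
    def __init__(self, quant_key: str) -> None:
        # The stored key both selects the quant op from the registry
        # and parametrizes the traced pattern.
        self.quant_key = quant_key

    @abstractmethod
    def register(self) -> None: ...

class DemoNvfp4Pattern(DemoQuantPattern):
    def __init__(self) -> None:
        # Mirrors super().__init__(layer, kNvfp4Dynamic, dtype) above:
        # the only change in this commit is which key gets passed.
        super().__init__("nvfp4-dynamic")

    def register(self) -> None:
        print(f"registering pattern for {self.quant_key}")

DemoNvfp4Pattern().register()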

View File

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kFp8DynamicTensorSym,
     kFp8DynamicTokenSym,
     kFp8StaticTensorSym,
-    kNvfp4Quant,
+    kNvfp4Dynamic,
 )
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
 }

 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
+    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501

 if current_platform.is_cuda():
     QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
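
Taken together, these passes treat a missing registry entry as "kernel unavailable on this platform, keep the unfused graph" rather than as an error, which is what the hasattr guards above buy. A small dispatch sketch under that assumption (maybe_fuse and the demo registry are hypothetical, not vLLM code):

from collections.abc import Callable

QUANT_OPS: dict[str, Callable[[float], float]] = {
    "nvfp4-dynamic": lambda x: x,  # stand-in for scaled_fp4_quant
}

def maybe_fuse(
    quant_key: str, unfused: Callable[[float], float]
) -> Callable[[float], float]:
    # A missing key means the platform lacks the kernel;
    # fall back to the unfused path silently.
    return QUANT_OPS.get(quant_key, unfused)

fn = maybe_fuse("nvfp4-dynamic", unfused=lambda x: x * 1.0)
print(fn(3.0))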