Signed-off-by: khluu <khluu000@gmail.com>
Author: Vadim Gimpelson
Date: 2026-03-26 12:21:47 +04:00
Committed by: khluu
Parent: ccbc5ac449
Commit: 05d96d7991

10 changed files with 73 additions and 10 deletions

View File

@@ -1,5 +1,6 @@
model_name: "Qwen/Qwen3.5-35B-A3B" model_name: "Qwen/Qwen3.5-35B-A3B"
accuracy_threshold: 0.86 accuracy_threshold: 0.84
tolerance: 0.03
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: >- server_args: >-

View File

@@ -1,5 +1,6 @@
model_name: "Qwen/Qwen3.5-35B-A3B-FP8" model_name: "Qwen/Qwen3.5-35B-A3B-FP8"
accuracy_threshold: 0.86 accuracy_threshold: 0.79
tolerance: 0.03
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
server_args: >- server_args: >-

View File

@@ -0,0 +1,9 @@
model_name: "nvidia/Qwen3.5-397B-A17B-NVFP4"
accuracy_threshold: 0.88
tolerance: 0.03
num_questions: 1319
num_fewshot: 5
server_args: >-
--max-model-len 4096
--data-parallel-size 2
--enable-expert-parallel
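
All three configs now carry a per-model tolerance next to accuracy_threshold. A minimal sketch of consuming such a file, assuming PyYAML; load_eval_config and the inline usage are illustrative, not the repo's actual harness:

import yaml  # PyYAML


def load_eval_config(path: str) -> dict:
    """Load one eval-config YAML into a plain dict."""
    with open(path) as f:
        return yaml.safe_load(f)


# Hypothetical usage with the NVFP4 config above:
# cfg = load_eval_config("Qwen3.5-397B-A17B-NVFP4-DEP2.yaml")
# floor = cfg["accuracy_threshold"] - cfg.get("tolerance", 0.08)
# => 0.88 - 0.03 = 0.85: a run passes if measured accuracy >= 0.85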

View File

@@ -1,2 +1,3 @@
 Qwen3.5-35B-A3B-DEP2.yaml
 Qwen3.5-35B-A3B-FP8-DEP2.yaml
+Qwen3.5-397B-A17B-NVFP4-DEP2.yaml

View File

@@ -19,8 +19,6 @@ from vllm.platforms import current_platform
 from .gsm8k_eval import evaluate_gsm8k
 
-TOL = 0.08  # Absolute tolerance for accuracy comparison
-
 def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
     """Run GSM8K evaluation using our isolated script."""
@@ -99,20 +97,20 @@ def test_gsm8k_correctness(config_filename):
measured_metric = results["accuracy"] measured_metric = results["accuracy"]
expected_metric = eval_config["accuracy_threshold"] expected_metric = eval_config["accuracy_threshold"]
tol = eval_config.get("tolerance", 0.08)
print(f"GSM8K Results for {eval_config['model_name']}:") print(f"GSM8K Results for {eval_config['model_name']}:")
print(f" Measured metric: {measured_metric:.4f}") print(f" Measured metric: {measured_metric:.4f}")
print(f" Expected metric: {expected_metric:.4f}") print(f" Expected metric: {expected_metric:.4f}")
print(f" Tolerance: {TOL:.4f}") print(f" Tolerance: {tol:.4f}")
print(f" Questions: {results['num_questions']}") print(f" Questions: {results['num_questions']}")
print(f" Invalid rate: {results['invalid_rate']:.3f}") print(f" Invalid rate: {results['invalid_rate']:.3f}")
print(f" Latency: {results['latency']:.1f}s") print(f" Latency: {results['latency']:.1f}s")
print(f" QPS: {results['questions_per_second']:.1f}") print(f" QPS: {results['questions_per_second']:.1f}")
# Verify metric is within tolerance assert measured_metric >= expected_metric - tol, (
assert measured_metric >= expected_metric - TOL, (
f"GSM8K metric too low: {measured_metric:.4f} < " f"GSM8K metric too low: {measured_metric:.4f} < "
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" f"{expected_metric:.4f} - {tol:.4f} = {expected_metric - tol:.4f}"
) )
print(f"✅ GSM8K test passed for {eval_config['model_name']}") print(f"✅ GSM8K test passed for {eval_config['model_name']}")

View File

@@ -682,6 +682,27 @@ class VllmConfig:
             self.model_config, self.load_config
         )
 
+        if (
+            self.quant_config is not None
+            and self.model_config is not None
+            and hasattr(self.quant_config, "use_deep_gemm")
+            and self.quant_config.use_deep_gemm is None
+        ):
+            from vllm.utils.deep_gemm import should_auto_disable_deep_gemm
+
+            model_type = getattr(self.model_config.hf_text_config, "model_type", None)
+            if should_auto_disable_deep_gemm(model_type):
+                self.quant_config.use_deep_gemm = False
+                logger.warning_once(
+                    "Auto-disabled DeepGemm for model_type=%s on Blackwell. "
+                    "DeepGemm E8M0 scale format causes accuracy degradation "
+                    "for this architecture. Falling back to CUTLASS. "
+                    "To disable DeepGemm globally, set VLLM_USE_DEEP_GEMM=0.",
+                    model_type,
+                )
+
         from vllm.v1.executor.abstract import Executor
 
         executor_backend = self.parallel_config.distributed_executor_backend
         executor_supports_async_sched = executor_backend in (
             "mp",

View File

@@ -135,6 +135,7 @@ class Fp8Config(QuantizationConfig):
f"{activation_scheme} activation scheme." f"{activation_scheme} activation scheme."
) )
self.weight_block_size = weight_block_size self.weight_block_size = weight_block_size
self.use_deep_gemm: bool | None = None
     @classmethod
     def get_name(cls) -> QuantizationMethods:
@@ -291,6 +292,9 @@ class Fp8LinearMethod(LinearMethodBase):
         self.use_marlin = False
         self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
-        self.use_deep_gemm = is_deep_gemm_supported()
+        if self.quant_config.use_deep_gemm is not None:
+            self.use_deep_gemm = self.quant_config.use_deep_gemm
+        else:
+            self.use_deep_gemm = is_deep_gemm_supported()
         self.weight_block_size = self.quant_config.weight_block_size
@@ -305,6 +309,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+                use_deep_gemm=self.use_deep_gemm,
             )
         else:
             # Use per-token quantization for better perf if dynamic and cutlass
@@ -432,6 +437,6 @@
         else:
             layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(
                 layer, size_k_first, input_dtype=self.marlin_input_dtype
@@ -441,6 +446,6 @@
             return
 
-        if self.block_quant:
+        if self.block_quant and self.use_deep_gemm:
             maybe_post_process_fp8_weight_block(layer)
 
     def apply(
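
use_deep_gemm is deliberately tri-state on Fp8Config: None means autodetect, while True/False is a pinned decision (here pinned to False by the VllmConfig guard). The resolution step in Fp8LinearMethod.__init__ reduces to this pattern (a sketch; the autodetect probe is stubbed):

def resolve_use_deep_gemm(override: bool | None, autodetect=lambda: True) -> bool:
    """None -> autodetect; an explicit True/False wins."""
    return override if override is not None else autodetect()


assert resolve_use_deep_gemm(None) is True    # unset: probe the platform
assert resolve_use_deep_gemm(False) is False  # Blackwell guard pins CUTLASS
assert resolve_use_deep_gemm(True) is True    # explicit opt-in is honored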

View File

@@ -91,6 +91,7 @@ class QuantFP8(CustomOp):
         if (
             self.is_group_quant
+            and self.use_ue8m0
             and self.use_deep_gemm_supported
             and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
         ):
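
The new self.use_ue8m0 term tightens the gate: group quant plus DeepGemm support no longer suffices to select the UE8M0 scale path. The predicate, written out as a standalone sketch over plain booleans:

def takes_ue8m0_path(is_group_quant: bool, use_ue8m0: bool,
                     deep_gemm_supported: bool, oracle_says_ue8m0: bool) -> bool:
    """All four conditions must hold; previously use_ue8m0 was not checked."""
    return (is_group_quant and use_ue8m0
            and deep_gemm_supported and oracle_says_ue8m0)


# Before this change the first call would have taken the UE8M0 path:
assert takes_ue8m0_path(True, False, True, True) is False
assert takes_ue8m0_path(True, True, True, True) is True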

View File

@@ -356,9 +356,13 @@ class W8A8BlockFp8LinearOp:
         act_quant_group_shape: GroupShape,
         cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
         use_aiter_and_is_supported: bool = False,
+        use_deep_gemm: bool | None = None,
     ):
         self.weight_group_shape = weight_group_shape
         self.act_quant_group_shape = act_quant_group_shape
-        self.is_deep_gemm_supported = is_deep_gemm_supported()
+        if use_deep_gemm is not None:
+            self.is_deep_gemm_supported = use_deep_gemm
+        else:
+            self.is_deep_gemm_supported = is_deep_gemm_supported()
         self.is_hopper = current_platform.is_device_capability(90)
         self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
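
With the pinned value stored as is_deep_gemm_supported, downstream kernel selection follows it. An illustrative dispatch under that assumption; the real selection logic in this class is more involved:

def pick_kernel(deep_gemm_allowed: bool, cutlass_ok: bool) -> str:
    """Illustrative: DeepGemm when allowed, else CUTLASS, else Triton."""
    if deep_gemm_allowed:
        return "deep_gemm"
    return "cutlass" if cutlass_ok else "triton"


# Blackwell + excluded model: the guard pinned use_deep_gemm=False upstream,
# so block-FP8 GEMMs route to CUTLASS here.
assert pick_kernel(False, cutlass_ok=True) == "cutlass"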

View File

@@ -23,6 +23,24 @@ from vllm.platforms import current_platform
 from vllm.utils.import_utils import has_deep_gemm
 from vllm.utils.math_utils import cdiv
 
+_DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES: set[str] = {
+    "qwen3_5_text",
+    "qwen3_5_moe_text",
+}
+
+
+def should_auto_disable_deep_gemm(model_type: str | None) -> bool:
+    """Check if DeepGemm should be auto-disabled for this model on Blackwell.
+
+    Returns True if the model is known to have accuracy degradation with
+    DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
+    """
+    if model_type is None:
+        return False
+    if not current_platform.is_device_capability_family(100):
+        return False
+    return model_type in _DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES
+
+
 class DeepGemmQuantScaleFMT(Enum):
     # Float32 scales in Float32 tensor
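
Expected behavior of the new helper, stated as usage notes (assuming the device-capability check behaves as named):

# On an SM100-family (Blackwell) device:
#   should_auto_disable_deep_gemm("qwen3_5_moe_text")  -> True
#   should_auto_disable_deep_gemm("qwen3_5_text")      -> True
#   should_auto_disable_deep_gemm("llama")             -> False
#   should_auto_disable_deep_gemm(None)                -> False  (unknown model type)
# On any non-Blackwell device the helper always returns False.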