From 05d96d7991cd3540989fa6038c7704bb54e1d310 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Date: Thu, 26 Mar 2026 12:21:47 +0400
Subject: [PATCH] Fix DeepGemm E8M0 accuracy degradation for Qwen3.5 FP8 on
 Blackwell

Auto-disable DeepGemm for Qwen3.5 text and MoE models on Blackwell
(SM100+), where DeepGemm's E8M0 scale format degrades accuracy, and fall
back to CUTLASS instead. Adds a per-config DeepGemm override on Fp8Config,
a per-config `tolerance` field for the GSM8K eval (default 0.08, matching
the old global TOL), and a new NVFP4 eval config.

Signed-off-by: khluu
---
 .../gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml   |  3 ++-
 .../configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml     |  3 ++-
 .../configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml |  9 ++++++++
 .../gsm8k/configs/models-qwen35-blackwell.txt |  1 +
 tests/evals/gsm8k/test_gsm8k_correctness.py   | 10 ++++-----
 vllm/config/vllm.py                           | 21 +++++++++++++++++++
 .../model_executor/layers/quantization/fp8.py |  9 +++++++--
 .../layers/quantization/input_quant_fp8.py    |  1 +
 .../layers/quantization/utils/fp8_utils.py    |  6 +++++-
 vllm/utils/deep_gemm.py                       | 18 ++++++++++++++++
 10 files changed, 70 insertions(+), 11 deletions(-)
 create mode 100644 tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml

diff --git a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml
index 62be504e2..55a134ad9 100644
--- a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml
@@ -1,5 +1,6 @@
 model_name: "Qwen/Qwen3.5-35B-A3B"
-accuracy_threshold: 0.86
+accuracy_threshold: 0.84
+tolerance: 0.03
 num_questions: 1319
 num_fewshot: 5
 server_args: >-
diff --git a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml
index 9380e0b25..7a36052e3 100644
--- a/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml
@@ -1,5 +1,6 @@
 model_name: "Qwen/Qwen3.5-35B-A3B-FP8"
-accuracy_threshold: 0.86
+accuracy_threshold: 0.79
+tolerance: 0.03
 num_questions: 1319
 num_fewshot: 5
 server_args: >-
diff --git a/tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml b/tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
new file mode 100644
index 000000000..cd35790c3
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
@@ -0,0 +1,9 @@
+model_name: "nvidia/Qwen3.5-397B-A17B-NVFP4"
+accuracy_threshold: 0.88
+tolerance: 0.03
+num_questions: 1319
+num_fewshot: 5
+server_args: >-
+  --max-model-len 4096
+  --data-parallel-size 2
+  --enable-expert-parallel
diff --git a/tests/evals/gsm8k/configs/models-qwen35-blackwell.txt b/tests/evals/gsm8k/configs/models-qwen35-blackwell.txt
index 774ae8eb7..908ada3a2 100644
--- a/tests/evals/gsm8k/configs/models-qwen35-blackwell.txt
+++ b/tests/evals/gsm8k/configs/models-qwen35-blackwell.txt
@@ -1,2 +1,3 @@
 Qwen3.5-35B-A3B-DEP2.yaml
 Qwen3.5-35B-A3B-FP8-DEP2.yaml
+Qwen3.5-397B-A17B-NVFP4-DEP2.yaml
\ No newline at end of file
diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py
index c8028c0b8..5025c46eb 100644
--- a/tests/evals/gsm8k/test_gsm8k_correctness.py
+++ b/tests/evals/gsm8k/test_gsm8k_correctness.py
@@ -19,8 +19,6 @@ from vllm.platforms import current_platform
 
 from .gsm8k_eval import evaluate_gsm8k
 
-TOL = 0.08  # Absolute tolerance for accuracy comparison
-
 
 def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
     """Run GSM8K evaluation using our isolated script."""
@@ -99,20 +97,20 @@ def test_gsm8k_correctness(config_filename):
 
     measured_metric = results["accuracy"]
     expected_metric = eval_config["accuracy_threshold"]
+    tol = eval_config.get("tolerance", 0.08)
 
     print(f"GSM8K Results for {eval_config['model_name']}:")
     print(f"  Measured metric: {measured_metric:.4f}")
     print(f"  Expected metric: {expected_metric:.4f}")
{expected_metric:.4f}") - print(f" Tolerance: {TOL:.4f}") + print(f" Tolerance: {tol:.4f}") print(f" Questions: {results['num_questions']}") print(f" Invalid rate: {results['invalid_rate']:.3f}") print(f" Latency: {results['latency']:.1f}s") print(f" QPS: {results['questions_per_second']:.1f}") - # Verify metric is within tolerance - assert measured_metric >= expected_metric - TOL, ( + assert measured_metric >= expected_metric - tol, ( f"GSM8K metric too low: {measured_metric:.4f} < " - f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" + f"{expected_metric:.4f} - {tol:.4f} = {expected_metric - tol:.4f}" ) print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 8cd114481..a1e8e1328 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -682,6 +682,27 @@ class VllmConfig: self.model_config, self.load_config ) + if ( + self.quant_config is not None + and self.model_config is not None + and hasattr(self.quant_config, "use_deep_gemm") + and self.quant_config.use_deep_gemm is None + ): + from vllm.utils.deep_gemm import should_auto_disable_deep_gemm + + model_type = getattr(self.model_config.hf_text_config, "model_type", None) + if should_auto_disable_deep_gemm(model_type): + self.quant_config.use_deep_gemm = False + logger.warning_once( + "Auto-disabled DeepGemm for model_type=%s on Blackwell. " + "DeepGemm E8M0 scale format causes accuracy degradation " + "for this architecture. Falling back to CUTLASS. " + "To disable DeepGemm globally, set VLLM_USE_DEEP_GEMM=0.", + model_type, + ) + + from vllm.v1.executor.abstract import Executor + executor_backend = self.parallel_config.distributed_executor_backend executor_supports_async_sched = executor_backend in ( "mp", diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5101347cd..8c901278f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -135,6 +135,7 @@ class Fp8Config(QuantizationConfig): f"{activation_scheme} activation scheme." 
             )
         self.weight_block_size = weight_block_size
+        self.use_deep_gemm: bool | None = None
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
@@ -291,7 +292,10 @@ class Fp8LinearMethod(LinearMethodBase):
         self.use_marlin = False
 
         self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
-        self.use_deep_gemm = is_deep_gemm_supported()
+        if self.quant_config.use_deep_gemm is not None:
+            self.use_deep_gemm = self.quant_config.use_deep_gemm
+        else:
+            self.use_deep_gemm = is_deep_gemm_supported()
 
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
@@ -305,6 +309,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+                use_deep_gemm=self.use_deep_gemm,
             )
         else:
             # Use per-token quantization for better perf if dynamic and cutlass
@@ -441,6 +446,6 @@ class Fp8LinearMethod(LinearMethodBase):
             return
 
-        if self.block_quant:
+        if self.block_quant and self.use_deep_gemm:
             maybe_post_process_fp8_weight_block(layer)
 
     def apply(
diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py
index 6fa85436d..5d4e54490 100644
--- a/vllm/model_executor/layers/quantization/input_quant_fp8.py
+++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py
@@ -91,6 +91,7 @@ class QuantFP8(CustomOp):
 
         if (
             self.is_group_quant
+            and self.use_ue8m0
             and self.use_deep_gemm_supported
             and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
         ):
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 78b123402..2b072368f 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -356,10 +356,14 @@ class W8A8BlockFp8LinearOp:
         act_quant_group_shape: GroupShape,
         cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
         use_aiter_and_is_supported: bool = False,
+        use_deep_gemm: bool | None = None,
     ):
         self.weight_group_shape = weight_group_shape
         self.act_quant_group_shape = act_quant_group_shape
-        self.is_deep_gemm_supported = is_deep_gemm_supported()
+        if use_deep_gemm is not None:
+            self.is_deep_gemm_supported = use_deep_gemm
+        else:
+            self.is_deep_gemm_supported = is_deep_gemm_supported()
         self.is_hopper = current_platform.is_device_capability(90)
         self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
         self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index ee104a6cc..8fcb1f321 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -23,6 +23,24 @@
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import has_deep_gemm
 from vllm.utils.math_utils import cdiv
 
+_DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES: set[str] = {
+    "qwen3_5_text",
+    "qwen3_5_moe_text",
+}
+
+
+def should_auto_disable_deep_gemm(model_type: str | None) -> bool:
+    """Check if DeepGemm should be auto-disabled for this model on Blackwell.
+
+    Returns True if the model is known to have accuracy degradation with
+    DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
+    """
+    if model_type is None:
+        return False
+    if not current_platform.is_device_capability_family(100):
+        return False
+    return model_type in _DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES
+
 
 class DeepGemmQuantScaleFMT(Enum):
     # Float32 scales in Float32 tensor
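
For reference, the decision order the patch implements, as a minimal
standalone sketch. Everything below is illustrative: `resolve_use_deep_gemm`
and its parameters are hypothetical and not part of the diff or of vLLM's
API; only the model-type set and the SM100+ check mirror
`should_auto_disable_deep_gemm` above.

    # Hypothetical sketch of the precedence (not vLLM code):
    # 1. an explicit quant_config override wins,
    # 2. otherwise the Blackwell E8M0 denylist disables DeepGemm,
    # 3. otherwise platform support decides.
    _BLACKWELL_EXCLUDED = {"qwen3_5_text", "qwen3_5_moe_text"}

    def resolve_use_deep_gemm(
        override: bool | None,     # quant_config.use_deep_gemm
        model_type: str | None,    # hf_text_config.model_type
        is_sm100_family: bool,     # Blackwell (SM100+) device
        platform_supported: bool,  # is_deep_gemm_supported()
    ) -> bool:
        if override is not None:
            return override
        if is_sm100_family and model_type in _BLACKWELL_EXCLUDED:
            # E8M0 scale rounding degrades accuracy for these models.
            return False
        return platform_supported

    # Qwen3.5 FP8 on Blackwell falls back to CUTLASS:
    assert resolve_use_deep_gemm(None, "qwen3_5_text", True, True) is False
    # An explicit override (set upstream, or by a user) still wins:
    assert resolve_use_deep_gemm(True, "qwen3_5_text", True, True) is True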