diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 996da4a29..b9913734d 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -290,6 +290,9 @@ def test_rms_group_quant( # Force spawn as it is more general. monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + # TODO: remove this override once the fusion is fixed + monkeypatch.setenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "0") + model_kwargs["attention_config"] = {"backend": backend.name} compilation_config = CompilationConfig( diff --git a/vllm/envs.py b/vllm/envs.py index e28f9c431..97351902b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -162,6 +162,7 @@ if TYPE_CHECKING: VLLM_USE_DEEP_GEMM: bool = True VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -1201,6 +1202,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) ), + # Whether to create TMA-aligned scale tensors when DeepGEMM is used. + "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( + int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. 
However, this warmup increases the engine diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 638f0f71a..3f1b14901 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -379,7 +379,7 @@ class W8A8BlockFp8LinearOp: False, self.act_quant_group_shape, column_major_scales=True, - tma_aligned_scales=True, + tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES, use_ue8m0=self.use_deep_gemm_e8m0, ) if self.is_deep_gemm_supported