diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index b7401e644..d450e81a8 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -14,9 +14,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, ) +from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float] -QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] +QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()] VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] # Avoid combinatorial explosion with full Cartesian product NUM_TOKENS_HIDDEN_SIZES = [ @@ -61,14 +62,14 @@ def ref_dynamic_per_token_or_block_quant( group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if scale_ub is not None: - assert quant_dtype == torch.float8_e4m3fn + assert quant_dtype == current_platform.fp8_dtype() # Norm torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual) # Quant if group_size is not None: - if quant_dtype == torch.float8_e4m3fn: + if quant_dtype == current_platform.fp8_dtype(): torch_out, scales = per_token_group_quant_fp8( torch_out, group_size=group_size[1], use_ue8m0=False ) @@ -78,7 +79,7 @@ def ref_dynamic_per_token_or_block_quant( torch_out, group_size=group_size[1] ) else: - if quant_dtype == torch.float8_e4m3fn: + if quant_dtype == current_platform.fp8_dtype(): torch_out, scales = ops.scaled_fp8_quant( torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True ) @@ -162,6 +163,7 @@ def test_rms_norm( if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.set_default_device(device) + torch.cuda.set_device(device) if group_size is not None and hidden_size % group_size[1] != 0: # skip @@ -171,7 +173,7 @@ def test_rms_norm( # blockwise baseline doesn't support scale_ub return - if has_scale_ub and quant_dtype != torch.float8_e4m3fn: + if has_scale_ub and quant_dtype != current_platform.fp8_dtype(): # skip return