diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 5f9d4ac53..1d4647651 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init(
         vllm_config.compilation_config.max_cudagraph_capture_size
         == expected_max_size
     )
+
+
+def test_cached_compilation_config():
+    import torch
+    from torch._inductor.utils import run_and_get_code
+
+    from vllm.config import get_cached_compilation_config, set_current_vllm_config
+    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+
+    dtype = torch.bfloat16
+    device = torch.device("cuda:0")
+    batch_size, num_qo_heads, head_size = 8, 16, 128
+
+    # Access and cache the default compilation config.
+    # The default compilation config does not contain the +quant_fp8 custom op. If it
+    # were used, the generated code would use an inductor-generated Triton kernel
+    # instead of the custom op `torch.ops._C.static_scaled_fp8_quant`.
+    get_cached_compilation_config()
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            custom_ops=["+quant_fp8"],
+        )
+    )
+
+    # set_current_vllm_config should clear the cached compilation config and
+    # use the new compilation_config in vllm_config.
+    with set_current_vllm_config(vllm_config):
+        query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
+        query_quant = torch.compile(query_quant)
+
+        _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
+        query = torch.randn(
+            batch_size, num_qo_heads * head_size, dtype=dtype, device=device
+        )
+
+        _, code = run_and_get_code(query_quant, query, _q_scale)
+
+        code = " ".join(code)
+        assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 0439dc52e..70319f98f 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -1360,6 +1360,11 @@ def set_current_vllm_config(
 
     num_models_seen = compilation_counter.num_models_seen
     try:
+        # Clear the compilation config cache when the context changes.
+        # This is needed since the old config may have been accessed
+        # and cached before the new config is set.
+        get_cached_compilation_config.cache_clear()
+
         _current_vllm_config = vllm_config
         _current_prefix = prefix
         yield