[BugFix] Fix cache issue in compilation_config (#31376)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng
2025-12-27 06:30:39 -08:00
committed by GitHub
parent 40a8756224
commit 2f12cd32c0
2 changed files with 47 additions and 0 deletions
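
Note on the companion change: the diff below shows only the new test; the other changed file (vllm/config, where get_cached_compilation_config and set_current_vllm_config live) is not included in this excerpt. As a minimal sketch of the pattern the fix addresses, assuming the accessor is a functools.lru_cache over a process-wide config (the module-level names and the cache_clear() calls here are assumptions for illustration, not quotes of the actual change):

# Minimal sketch, not the actual vllm/config source. Assumes an lru_cache-style
# accessor; the cache_clear() calls illustrate the presumed fix.
from contextlib import contextmanager
from functools import lru_cache

_current_vllm_config = None  # stand-in for vLLM's process-wide config


def get_current_vllm_config():
    return _current_vllm_config


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    # Cached so hot paths avoid re-resolving the current config on every call.
    return get_current_vllm_config().compilation_config


@contextmanager
def set_current_vllm_config(vllm_config):
    global _current_vllm_config
    old_config = _current_vllm_config
    # The bug: without this, a previously cached compilation config (e.g. the
    # default one) keeps being served after a new config is installed.
    get_cached_compilation_config.cache_clear()
    _current_vllm_config = vllm_config
    try:
        yield
    finally:
        _current_vllm_config = old_config
        # Invalidate again on exit so the restored config is re-resolved
        # (an assumption about scope; the commit may handle this differently).
        get_cached_compilation_config.cache_clear()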

@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init(
        vllm_config.compilation_config.max_cudagraph_capture_size
        == expected_max_size
    )


def test_cached_compilation_config():
    import torch
    from torch._inductor.utils import run_and_get_code

    from vllm.config import get_cached_compilation_config, set_current_vllm_config
    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    batch_size, num_qo_heads, head_size = 8, 16, 128

    # Access and cache the default compilation config. The default config does
    # not enable the +quant_fp8 custom op; if a stale cached copy of it were
    # used below, the generated code would call an inductor-generated triton
    # kernel instead of the custom op `torch.ops._C.static_scaled_fp8_quant`.
    get_cached_compilation_config()

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+quant_fp8"],
        )
    )

    # set_current_vllm_config should clear the cached compilation config and
    # use the new compilation_config from vllm_config.
    with set_current_vllm_config(vllm_config):
        query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
        query_quant = torch.compile(query_quant)

        _q_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        query = torch.randn(
            batch_size, num_qo_heads * head_size, dtype=dtype, device=device
        )

        _, code = run_and_get_code(query_quant, query, _q_scale)
        code = " ".join(code)
        # Passes only if the newly installed config (with +quant_fp8) is in
        # effect, i.e. the cached default config was actually invalidated.
        assert "torch.ops._C.static_scaled_fp8_quant.default(" in code