[BugFix] Fix cache issue in compilation_config (#31376)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
@@ -428,3 +428,45 @@ def test_cudagraph_sizes_post_init(
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
== expected_max_size
|
||||
)
|
||||
|
||||
|
||||
def test_cached_compilation_config():
|
||||
import torch
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
|
||||
from vllm.config import get_cached_compilation_config, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda:0")
|
||||
batch_size, num_qo_heads, head_size = 8, 16, 128
|
||||
|
||||
# access and cache default compilation config
|
||||
# default compilation config does not contain +quant_fp8 custom op. If this is
|
||||
# used, the generated code would use inductor-generated triton kernel instead
|
||||
# of the custom op `torch.ops._C.static_scaled_fp8_quant`.
|
||||
get_cached_compilation_config()
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
)
|
||||
|
||||
# set_current_vllm_config should clear cached compilation config and
|
||||
# use the new compilation_config in vllm_config
|
||||
with set_current_vllm_config(vllm_config):
|
||||
query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||
query_quant = torch.compile(query_quant)
|
||||
|
||||
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
|
||||
query = torch.randn(
|
||||
batch_size, num_qo_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
_, code = run_and_get_code(query_quant, query, _q_scale)
|
||||
|
||||
code = " ".join(code)
|
||||
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
|
||||
|
||||
Reference in New Issue
Block a user