[3/N][torch.compile] consolidate custom op logging (#10399)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -4,8 +4,9 @@ import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
|
||||
Literal, Mapping, Optional, Set, Tuple, Type, Union)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
|
||||
Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@@ -2169,6 +2170,10 @@ class CompilationConfig(BaseModel):
|
||||
compile_sizes: List[int] = PrivateAttr
|
||||
capture_sizes: List[int] = PrivateAttr
|
||||
|
||||
# keep track of enabled and disabled custom ops
|
||||
enabled_custom_ops: Counter[str] = PrivateAttr
|
||||
disabled_custom_ops: Counter[str] = PrivateAttr
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self.level = envs.VLLM_TORCH_COMPILE_LEVEL
|
||||
|
||||
@@ -2190,6 +2195,9 @@ class CompilationConfig(BaseModel):
|
||||
func = __import__(module).__dict__[func_name]
|
||||
self.inductor_compile_config[k] = func
|
||||
|
||||
self.enabled_custom_ops = Counter()
|
||||
self.disabled_custom_ops = Counter()
|
||||
|
||||
def init_backend(self) -> Union[str, Callable]:
|
||||
if self.level == CompilationLevel.NO_COMPILATION:
|
||||
raise ValueError("No compilation level is set.")
|
||||
|
||||
Reference in New Issue
Block a user